forked from go-gorm/gen
-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
19 changed files
with
2,362 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,321 @@ | ||
package gen | ||
|
||
import ( | ||
"bytes" | ||
"fmt" | ||
"io" | ||
"os" | ||
"path/filepath" | ||
"strconv" | ||
"strings" | ||
"text/template" | ||
|
||
"gorm.io/gorm" | ||
|
||
"golang.org/x/tools/imports" | ||
"gorm.io/gen/internal/check" | ||
"gorm.io/gen/internal/parser" | ||
tmpl "gorm.io/gen/internal/template" | ||
"gorm.io/gen/log" | ||
) | ||
|
||
// TODO implement some unit tests | ||
|
||
// T genric type | ||
type T interface{} | ||
|
||
// NewGenerator create a new generator | ||
func NewGenerator(cfg Config) *Generator { | ||
if cfg.modelPkgName == "" { | ||
cfg.modelPkgName = check.ModelPkg | ||
} | ||
return &Generator{ | ||
Config: cfg, | ||
Data: make(map[string]*genInfo), | ||
readInterfaceSet: new(parser.InterfaceSet), | ||
} | ||
} | ||
|
||
// Config generator's basic configuration | ||
type Config struct { | ||
OutPath string | ||
OutFile string | ||
|
||
pkgName string | ||
modelPkgName string //default model | ||
db *gorm.DB //nolint | ||
} | ||
|
||
func (c *Config) SetModelPkg(name string) { | ||
c.modelPkgName = name | ||
} | ||
|
||
func (c *Config) SetPkg(name string) { | ||
c.pkgName = name | ||
} | ||
|
||
func (c *Config) GetPkg() string { | ||
return c.pkgName | ||
} | ||
|
||
// genInfo info about generated code | ||
type genInfo struct { | ||
*check.BaseStruct | ||
Interfaces []*check.InterfaceMethod | ||
} | ||
|
||
// Generator code generator | ||
type Generator struct { | ||
Config | ||
|
||
Data map[string]*genInfo | ||
readInterfaceSet *parser.InterfaceSet | ||
} | ||
|
||
// UseDB set db connection | ||
func (g *Generator) UseDB(db *gorm.DB) { | ||
g.db = db | ||
} | ||
|
||
// Tables collect table model | ||
func (g *Generator) Tables(models ...interface{}) { | ||
structs, err := check.CheckStructs(g.db, models...) | ||
if err != nil { | ||
log.Fatalf("gen struct error: %s", err) | ||
} | ||
for _, interfaceStruct := range structs { | ||
data := g.getData(interfaceStruct.NewStructName) | ||
if data.BaseStruct == nil { | ||
data.BaseStruct = interfaceStruct | ||
} | ||
} | ||
} | ||
|
||
// TableNames collect table names | ||
func (g *Generator) TableNames(names ...string) { | ||
structs, err := check.GenBaseStructs(g.db, g.Config.modelPkgName, names...) | ||
if err != nil { | ||
log.Fatalf("check struct error: %s", err) | ||
} | ||
for _, interfaceStruct := range structs { | ||
data := g.getData(interfaceStruct.NewStructName) | ||
if data.BaseStruct == nil { | ||
data.BaseStruct = interfaceStruct | ||
} | ||
} | ||
} | ||
|
||
// Apply specifies method interfaces on structures, implment codes will be generated after calling g.Execute() | ||
// eg: g.Apply(func(model.Method){}, model.User{}, model.Company{}) | ||
func (g *Generator) Apply(fc interface{}, models ...interface{}) { | ||
var err error | ||
|
||
structs, err := check.CheckStructs(g.db, models...) | ||
if err != nil { | ||
log.Fatalf("check struct error: %s", err) | ||
} | ||
g.apply(fc, structs) | ||
} | ||
|
||
func (g *Generator) apply(fc interface{}, structs []*check.BaseStruct) { | ||
interfacePaths, err := parser.GetInterfacePath(fc) | ||
if err != nil { | ||
log.Fatalf("can't get interface name or file: %s", err) | ||
} | ||
|
||
err = g.readInterfaceSet.ParseFile(interfacePaths) | ||
if err != nil { | ||
log.Fatalf("parser file error: %s", err) | ||
} | ||
|
||
for _, interfaceStruct := range structs { | ||
data := g.getData(interfaceStruct.NewStructName) | ||
if data.BaseStruct == nil { | ||
data.BaseStruct = interfaceStruct | ||
} | ||
|
||
functions, err := check.CheckInterface(g.readInterfaceSet, interfaceStruct) | ||
if err != nil { | ||
log.Fatalf("check interface error: %s", err) | ||
} | ||
|
||
for _, function := range functions { | ||
data.Interfaces = function.DupAppend(data.Interfaces) | ||
} | ||
} | ||
} | ||
|
||
// ApplyByModel specifies one method interface on several model structures | ||
// eg: g.ApplyByModel(model.User{}, func(model.Method1, model.Method2){}) | ||
func (g *Generator) ApplyByModel(model interface{}, fc interface{}) { | ||
g.Apply(fc, model) | ||
} | ||
|
||
// ApplyByTable specifies table by table names | ||
// eg: g.ApplyByTable(func(model.Model){}, "user", "role") | ||
func (g *Generator) ApplyByTable(fc interface{}, tableNames ...string) { | ||
structs, err := check.GenBaseStructs(g.db, g.Config.modelPkgName, tableNames...) | ||
if err != nil { | ||
log.Fatalf("gen struct error: %s", err) | ||
} | ||
g.apply(fc, structs) | ||
} | ||
|
||
// Execute generate code to output path | ||
func (g *Generator) Execute() { | ||
var err error | ||
if g.OutPath == "" { | ||
g.OutPath = "./query" | ||
} | ||
if g.OutFile == "" { | ||
g.OutFile = g.OutPath + "/gorm_generated.go" | ||
} | ||
if _, err := os.Stat(g.OutPath); err != nil { | ||
if err := os.Mkdir(g.OutPath, os.ModePerm); err != nil { | ||
log.Fatalf("mkdir failed: %s", err) | ||
} | ||
} | ||
g.SetPkg(filepath.Base(g.OutPath)) | ||
|
||
err = g.generatedBaseStruct() | ||
if err != nil { | ||
log.Fatalf("generate base struct fail: %s", err) | ||
} | ||
err = g.generatedToOutFile() | ||
if err != nil { | ||
log.Fatalf("generate to file fail: %s", err) | ||
} | ||
log.Println("Gorm generated interface file successful!") | ||
log.Println("Generated path:", g.OutPath) | ||
log.Println("Generated file:", g.OutFile) | ||
} | ||
|
||
// generatedToOutFile save generate code to file | ||
func (g *Generator) generatedToOutFile() (err error) { | ||
var buf bytes.Buffer | ||
|
||
render := func(tmpl string, wr io.Writer, data interface{}) error { | ||
t, err := template.New(tmpl).Parse(tmpl) | ||
if err != nil { | ||
return err | ||
} | ||
return t.Execute(wr, data) | ||
} | ||
|
||
err = render(tmpl.HeaderTmpl, &buf, g.GetPkg()) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
for _, data := range g.Data { | ||
err = render(tmpl.BaseStruct, &buf, data.BaseStruct) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
for _, method := range data.Interfaces { | ||
err = render(tmpl.FuncTmpl, &buf, method) | ||
if err != nil { | ||
return err | ||
} | ||
} | ||
|
||
err = render(tmpl.BaseGormFunc, &buf, data.BaseStruct) | ||
if err != nil { | ||
return err | ||
} | ||
} | ||
|
||
err = render(tmpl.UseTmpl, &buf, g) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
result, err := imports.Process(g.OutFile, buf.Bytes(), nil) | ||
if err != nil { | ||
errLine, _ := strconv.Atoi(strings.Split(err.Error(), ":")[1]) | ||
line := strings.Split(buf.String(), "\n") | ||
for i := -3; i < 3; i++ { | ||
fmt.Println(i+errLine, line[i+errLine]) | ||
} | ||
return fmt.Errorf("can't format generated file: %w", err) | ||
} | ||
return outputFile(g.OutFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, result) | ||
} | ||
|
||
// generatedBaseStruct generate basic structures | ||
func (g *Generator) generatedBaseStruct() (err error) { | ||
outPath, err := filepath.Abs(g.OutPath) | ||
if err != nil { | ||
return err | ||
} | ||
pkg := g.modelPkgName | ||
if pkg == "" { | ||
pkg = check.ModelPkg | ||
} | ||
outPath = fmt.Sprint(filepath.Dir(outPath), "/", pkg, "/") | ||
if _, err := os.Stat(outPath); err != nil { | ||
if err := os.Mkdir(outPath, os.ModePerm); err != nil { | ||
log.Fatalf("mkdir failed: %s", err) | ||
} | ||
} | ||
for _, data := range g.Data { | ||
if data.BaseStruct == nil || !data.BaseStruct.GenBaseStruct { | ||
continue | ||
} | ||
var buf bytes.Buffer | ||
err = render(tmpl.ModelTemplate, &buf, data.BaseStruct) | ||
if err != nil { | ||
return err | ||
} | ||
modelFile := fmt.Sprint(outPath, data.BaseStruct.TableName, ".go") | ||
result, err := imports.Process(modelFile, buf.Bytes(), nil) | ||
if err != nil { | ||
for i, line := range strings.Split(buf.String(), "\n") { | ||
fmt.Println(i, line) | ||
} | ||
return fmt.Errorf("can't format generated file: %w", err) | ||
} | ||
err = outputFile(modelFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, result) | ||
if err != nil { | ||
return nil | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
func (g *Generator) getData(structName string) *genInfo { | ||
if g.Data[structName] == nil { | ||
g.Data[structName] = new(genInfo) | ||
} | ||
return g.Data[structName] | ||
} | ||
|
||
func outputFile(filename string, flag int, data []byte) error { | ||
out, err := os.OpenFile(filename, flag, 0640) | ||
if err != nil { | ||
return fmt.Errorf("can't open out file: %w", err) | ||
} | ||
return output(out, data) | ||
} | ||
|
||
func output(wr io.WriteCloser, data []byte) (err error) { | ||
defer func() { | ||
if e := wr.Close(); e != nil { | ||
err = fmt.Errorf("can't close: %w", e) | ||
} | ||
}() | ||
|
||
if _, err = wr.Write(data); err != nil { | ||
return fmt.Errorf("can't write: %w", err) | ||
} | ||
return nil | ||
} | ||
|
||
func render(tmpl string, wr io.Writer, data interface{}) error { | ||
t, err := template.New(tmpl).Parse(tmpl) | ||
if err != nil { | ||
return err | ||
} | ||
return t.Execute(wr, data) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,38 @@ | ||
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= | ||
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= | ||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= | ||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= | ||
github.com/jinzhu/now v1.1.2 h1:eVKgfIdy9b6zbWBMgFpfDPoAMifwSZagU9HmEU6zgiI= | ||
github.com/jinzhu/now v1.1.2/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= | ||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= | ||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= | ||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= | ||
golang.org/x/mod v0.4.2 h1:Gz96sIWK3OalVv/I/qNygP42zyoKp3xptRVCWRFEBvo= | ||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= | ||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= | ||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= | ||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= | ||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007 h1:gG67DSER+11cZvqIMb8S8bt0vZtiN6xWYARwirrOSfE= | ||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= | ||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= | ||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= | ||
golang.org/x/tools v0.1.5 h1:ouewzE6p+/VEB31YYnTbEJdi8pFqKp4P4n85vwo3DHA= | ||
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= | ||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= | ||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||
gorm.io/driver/mysql v1.1.1 h1:yr1bpyqiwuSPJ4aGGUX9nu46RHXlF8RASQVb1QQNcvo= | ||
gorm.io/driver/mysql v1.1.1/go.mod h1:KdrTanmfLPPyAOeYGyG+UpDys7/7eeWT1zCq+oekYnU= | ||
gorm.io/gorm v1.21.9/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0= | ||
gorm.io/gorm v1.21.12 h1:3fQM0Eiz7jcJEhPggHEpoYnsGZqynMzverL77DV40RM= | ||
gorm.io/gorm v1.21.12/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0= |
Oops, something went wrong.