Skip to content

Commit

Permalink
fix:data race(go-gorm#588)
Browse files Browse the repository at this point in the history
  • Loading branch information
tr1v3r committed Aug 29, 2022
1 parent 587ed13 commit e0c4067
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
12 changes: 6 additions & 6 deletions generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ func (g *Generator) generateQueryFile() (err error) {
pool.Wait()
go func(info *genInfo) {
defer pool.Done()
err = g.generateSingleQueryFile(info)
err := g.generateSingleQueryFile(info)
if err != nil {
errChan <- err
}
Expand All @@ -310,7 +310,7 @@ func (g *Generator) generateQueryFile() (err error) {
var buf bytes.Buffer
err = render(tmpl.Header, &buf, map[string]interface{}{
"Package": g.queryPkgName,
"ImportPkgPaths": importList.Add(g.importPkgPaths...).Output(),
"ImportPkgPaths": importList.Add(g.importPkgPaths...).Paths(),
})
if err != nil {
return err
Expand Down Expand Up @@ -339,7 +339,7 @@ func (g *Generator) generateQueryFile() (err error) {

err = render(tmpl.Header, &buf, map[string]interface{}{
"Package": g.queryPkgName,
"ImportPkgPaths": unitTestImportList.Add(g.importPkgPaths...).Output(),
"ImportPkgPaths": unitTestImportList.Add(g.importPkgPaths...).Paths(),
})
if err != nil {
g.db.Logger.Error(context.Background(), "generate query unit test fail: %s", err)
Expand Down Expand Up @@ -376,7 +376,7 @@ func (g *Generator) generateSingleQueryFile(data *genInfo) (err error) {
}
err = render(tmpl.Header, &buf, map[string]interface{}{
"Package": g.queryPkgName,
"ImportPkgPaths": importList.Add(structPkgPath).Add(getImportPkgPaths(data)...).Output(),
"ImportPkgPaths": importList.Add(structPkgPath).Add(getImportPkgPaths(data)...).Paths(),
})
if err != nil {
return err
Expand Down Expand Up @@ -426,7 +426,7 @@ func (g *Generator) generateQueryUnitTestFile(data *genInfo) (err error) {
}
err = render(tmpl.Header, &buf, map[string]interface{}{
"Package": g.queryPkgName,
"ImportPkgPaths": unitTestImportList.Add(structPkgPath).Add(data.ImportPkgPaths...).Output(),
"ImportPkgPaths": unitTestImportList.Add(structPkgPath).Add(data.ImportPkgPaths...).Paths(),
})
if err != nil {
return err
Expand Down Expand Up @@ -474,7 +474,7 @@ func (g *Generator) generateModelFile() error {
defer pool.Done()

var buf bytes.Buffer
err = render(tmpl.Model, &buf, data)
err := render(tmpl.Model, &buf, data)
if err != nil {
errChan <- err
return
Expand Down
22 changes: 15 additions & 7 deletions import.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package gen
import "strings"

var (
importList = importPkgS{}.Add(
importList = new(importPkgS).Add(
"context",
"database/sql",
"strings",
Expand All @@ -18,7 +18,7 @@ var (
"",
"gorm.io/plugin/dbresolver",
)
unitTestImportList = importPkgS{}.Add(
unitTestImportList = new(importPkgS).Add(
"context",
"fmt",
"strconv",
Expand All @@ -29,18 +29,23 @@ var (
)
)

type importPkgS struct{ paths []string }
type importPkgS struct {
paths []string
}

func (ip importPkgS) Add(paths ...string) *importPkgS {
purePaths := make([]string, 0, len(paths)+1)
for _, p := range paths {
p = strings.TrimSpace(p)
if p == "" {
ip.paths = append(ip.paths, p)
purePaths = append(purePaths, p)
continue
}

if p[len(p)-1] != '"' {
p = `"` + p + `"`
}

var exists bool
for _, existsP := range ip.paths {
if p == existsP {
Expand All @@ -49,11 +54,14 @@ func (ip importPkgS) Add(paths ...string) *importPkgS {
}
}
if !exists {
ip.paths = append(ip.paths, p)
purePaths = append(purePaths, p)
}
}
ip.paths = append(ip.paths, "")
purePaths = append(purePaths, "")

ip.paths = append(ip.paths, purePaths...)

return &ip
}

func (ip *importPkgS) Output() []string { return ip.paths }
func (ip importPkgS) Paths() []string { return ip.paths }

0 comments on commit e0c4067

Please sign in to comment.