Skip to content

Commit

Permalink
Merge pull request go-gorm#12 from go-gorm/improve-register-table-name
Browse files Browse the repository at this point in the history
Fix Gorm Model config support and PKPGSequence sequence key create
  • Loading branch information
huacnlee authored Feb 22, 2022
2 parents 63d6aa9 + b401902 commit 7db4d6d
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 21 deletions.
4 changes: 4 additions & 0 deletions primary_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ func (s *Sharding) genPostgreSQLSequenceKey(tableName string, index int64) int64
return id
}

func (s *Sharding) createPostgreSQLSequenceKeyIfNotExist(tableName string) error {
return s.DB.Exec(`CREATE SEQUENCE IF NOT EXISTS "` + pgSeqName(tableName) + `" START 1`).Error
}

func pgSeqName(table string) string {
return fmt.Sprintf("gorm_sharding_%s_id_seq", table)
}
42 changes: 29 additions & 13 deletions sharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/bwmarrin/snowflake"
"github.com/longbridgeapp/sqlparser"
"gorm.io/gorm"
"gorm.io/gorm/schema"
)

var (
Expand All @@ -25,6 +24,9 @@ type Sharding struct {
configs map[string]Config
querys sync.Map
snowflakeNodes []*snowflake.Node

_config Config
_tables []interface{}
}

// Config specifies the configuration for sharding.
Expand Down Expand Up @@ -80,20 +82,27 @@ type Config struct {
}

func Register(config Config, tables ...interface{}) *Sharding {
return (&Sharding{}).Register(config, tables...)
return &Sharding{
_config: config,
_tables: tables,
}
}

func (s *Sharding) Register(config Config, tables ...interface{}) *Sharding {
func (s *Sharding) compile() error {
if s.configs == nil {
s.configs = make(map[string]Config)
}
for _, table := range tables {
for _, table := range s._tables {
if t, ok := table.(string); ok {
s.configs[t] = config
} else if t, ok := table.(schema.Tabler); ok {
s.configs[t.TableName()] = config
s.configs[t] = s._config
} else {
panic("invalid config, use string table name or schema.Tabler")
// stmt := &gorm.Statement{DB: s.DB}
stmt := s.DB.Statement
if err := stmt.Parse(table); err == nil {
s.configs[stmt.Table] = s._config
} else {
return err
}
}
}

Expand All @@ -105,20 +114,27 @@ func (s *Sharding) Register(config Config, tables ...interface{}) *Sharding {
if c.PrimaryKeyGenerator == PKSnowflake {
c.PrimaryKeyGeneratorFn = s.genSnowflakeKey
} else if c.PrimaryKeyGenerator == PKPGSequence {

// Execute SQL to CREATE SEQUENCE for this table if not exist
err := s.createPostgreSQLSequenceKeyIfNotExist(t)
if err != nil {
return err
}

c.PrimaryKeyGeneratorFn = func(index int64) int64 {
return s.genPostgreSQLSequenceKey(t, index)
}
} else if c.PrimaryKeyGenerator == PKCustom {
if c.PrimaryKeyGeneratorFn == nil {
panic("PrimaryKeyGeneratorFn not configured")
return errors.New("PrimaryKeyGeneratorFn is required when use PKCustom")
}
} else {
panic("PrimaryKeyGenerator can only be one of PKSnowflake, PKPGSequence and PKCustom")
return errors.New("PrimaryKeyGenerator can only be one of PKSnowflake, PKPGSequence and PKCustom")
}

if c.ShardingAlgorithm == nil {
if c.NumberOfShards == 0 {
panic("specify NumberOfShards or ShardingAlgorithm")
return errors.New("specify NumberOfShards or ShardingAlgorithm")
}
if c.NumberOfShards < 10 {
c.tableFormat = "_%01d"
Expand Down Expand Up @@ -159,7 +175,7 @@ func (s *Sharding) Register(config Config, tables ...interface{}) *Sharding {
s.configs[t] = c
}

return s
return nil
}

// Name plugin name for Gorm plugin interface
Expand Down Expand Up @@ -199,7 +215,7 @@ func (s *Sharding) Initialize(db *gorm.DB) error {
s.snowflakeNodes[i] = n
}

return nil
return s.compile()
}

// resolve split the old query to full table query and sharding table query
Expand Down
11 changes: 3 additions & 8 deletions sharding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,11 @@ type Order struct {
Product string
}

func (Order) TableName() string {
return "orders"
}

type Category struct {
ID int64 `gorm:"primarykey"`
Name string
}

func (Category) TableName() string {
return "categories"
}

func databaseURL() string {
databaseURL := os.Getenv("DATABASE_URL")
if len(databaseURL) == 0 {
Expand Down Expand Up @@ -74,7 +66,9 @@ func init() {
})
}

fmt.Println("Clean only tables ...")
dropTables()
fmt.Println("AutoMigrate tables ...")
err := db.AutoMigrate(&Order{}, &Category{})
if err != nil {
panic(err)
Expand All @@ -95,6 +89,7 @@ func dropTables() {
tables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"}
for _, table := range tables {
db.Exec("DROP TABLE IF EXISTS " + table)
db.Exec(("DROP SEQUENCE IF EXISTS gorm_sharding_" + table + "_id_seq"))
}
}

Expand Down

0 comments on commit 7db4d6d

Please sign in to comment.