Skip to content

Commit

Permalink
Simplify config
Browse files Browse the repository at this point in the history
  • Loading branch information
hyperphoton committed Jan 25, 2022
1 parent 5ecc36c commit ace1265
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 69 deletions.
24 changes: 6 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,13 @@ Config the sharding middleware, register the tables which you want to shard. See

```go
db.Use(sharding.Register(sharding.Config{
ShardingKey: "user_id",
ShardingAlgorithm: func(value interface{}) (suffix string, err error) {
if user_id, ok := value.(int64); ok {
return fmt.Sprintf("_%02d", user_id%64), nil
}
return "", errors.New("invalid user_id")
},
PrimaryKeyGenerator: sharding.PKSnowflake,
ShardingKey: "user_id",
ShardingNumber: 64,
PrimaryKeyGenerator: PKSnowflake,
}, "orders").Register(sharding.Config{
ShardingKey: "user_id",
ShardingAlgorithm: func(value interface{}) (suffix string, err error) {
if user_id, ok := value.(int64); ok {
return fmt.Sprintf("_%02d", user_id%256), nil
}
return "", errors.New("invalid user_id")
},
PrimaryKeyGenerate: func(tableIdx int64) int64 {
return snowflake_node.Generate().Int64()
}
ShardingKey: "user_id",
ShardingNumber: 256,
PrimaryKeyGenerator: PKSnowflake,
// This case for show up give notifications, audit_logs table use same sharding rule.
}, Notification{}, AuditLog{}))
```
Expand Down
10 changes: 2 additions & 8 deletions examples/order.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"errors"
"fmt"

"gorm.io/driver/postgres"
Expand Down Expand Up @@ -33,13 +32,8 @@ func main() {
}

middleware := sharding.Register(sharding.Config{
ShardingKey: "user_id",
ShardingAlgorithm: func(value interface{}) (suffix string, err error) {
if uid, ok := value.(int64); ok {
return fmt.Sprintf("_%02d", uid%64), nil
}
return "", errors.New("invalid user_id")
},
ShardingKey: "user_id",
ShardingNumber: 64,
PrimaryKeyGenerator: sharding.PKSnowflake,
}, "orders")
db.Use(middleware)
Expand Down
48 changes: 48 additions & 0 deletions sharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sharding
import (
"errors"
"fmt"
"hash/crc32"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -41,6 +42,12 @@ type Config struct {
// For example, for a product order table, you may want to split the rows by `user_id`.
ShardingKey string

// ShardingNumber specifies how many tables you want to sharding.
ShardingNumber uint

// TableFormat specifies the sharding table suffix format.
TableFormat string

// ShardingAlgorithm specifies a function to generate the sharding
// table's suffix by the column value.
// For example, this function implements a mod sharding algorithm.
Expand Down Expand Up @@ -118,6 +125,47 @@ func (s *Sharding) Register(config Config, tables ...interface{}) *Sharding {
} else {
panic("PrimaryKeyGenerator can only be one of PKSnowflake, PKPGSequence and PKCustom")
}

if c.ShardingAlgorithm == nil {
if c.ShardingNumber == 0 {
panic("specify ShardingNumber or ShardingAlgorithm")
}
if c.ShardingNumber < 10 {
c.TableFormat = "_%01d"
} else if c.ShardingNumber < 100 {
c.TableFormat = "_%02d"
} else if c.ShardingNumber < 1000 {
c.TableFormat = "_%03d"
} else if c.ShardingNumber < 10000 {
c.TableFormat = "_%04d"
}
c.ShardingAlgorithm = func(value interface{}) (suffix string, err error) {
id := 0
switch value := value.(type) {
case int:
id = value
case int64:
id = int(value)
case string:
id, err = strconv.Atoi(value)
if err != nil {
id = int(crc32.ChecksumIEEE([]byte(value)))
}
default:
return "", fmt.Errorf("default algorithm only support integer and string column," +
"if you use other type, specify you own ShardingAlgorithm")
}
return fmt.Sprintf(c.TableFormat, id%int(c.ShardingNumber)), nil
}
}

if c.ShardingAlgorithmByPrimaryKey == nil {
if c.PrimaryKeyGenerator == PKSnowflake {
c.ShardingAlgorithmByPrimaryKey = func(id int64) (suffix string) {
return fmt.Sprintf(c.TableFormat, snowflake.ParseInt64(id).Node())
}
}
}
s.configs[t] = c
}

Expand Down
64 changes: 21 additions & 43 deletions sharding_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package sharding

import (
"fmt"
"os"
"strconv"
"testing"

"github.com/bwmarrin/snowflake"
Expand Down Expand Up @@ -50,34 +48,14 @@ var (
})

shardingConfig = Config{
DoubleWrite: true,
ShardingKey: "user_id",
ShardingAlgorithm: func(value interface{}) (suffix string, err error) {
userId := 0
switch value := value.(type) {
case int:
userId = value
case int64:
userId = int(value)
case string:
userId, err = strconv.Atoi(value)
if err != nil {
return "", err
}
default:
return "", err
}
return fmt.Sprintf("_%02d", userId%4), nil
},
ShardingAlgorithmByPrimaryKey: func(id int64) (suffix string) {
return fmt.Sprintf("_%02d", snowflake.ParseInt64(id).Node())
},
DoubleWrite: true,
ShardingKey: "user_id",
ShardingNumber: 4,
PrimaryKeyGenerator: PKSnowflake,
}

middleware = Register(shardingConfig, &Order{})

node, _ = snowflake.NewNode(1)
node, _ = snowflake.NewNode(1)
)

func init() {
Expand All @@ -86,7 +64,7 @@ func init() {
if err != nil {
panic(err)
}
stables := []string{"orders_00", "orders_01", "orders_02", "orders_03"}
stables := []string{"orders_0", "orders_1", "orders_2", "orders_3"}
for _, table := range stables {
db.Exec(`CREATE TABLE ` + table + ` (
id BIGSERIAL PRIMARY KEY,
Expand All @@ -99,66 +77,66 @@ func init() {
}

func dropTables() {
tables := []string{"orders", "orders_00", "orders_01", "orders_02", "orders_03", "categories"}
tables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"}
for _, table := range tables {
db.Exec("DROP TABLE IF EXISTS " + table)
}
}

func TestInsert(t *testing.T) {
tx := db.Create(&Order{ID: 100, UserID: 100, Product: "iPhone"})
assertQueryResult(t, `INSERT INTO "orders_00" ("user_id", "product", "id") VALUES ($1, $2, $3) RETURNING "id"`, tx)
assertQueryResult(t, `INSERT INTO "orders_0" ("user_id", "product", "id") VALUES ($1, $2, $3) RETURNING "id"`, tx)
}

func TestFillID(t *testing.T) {
db.Create(&Order{UserID: 100, Product: "iPhone"})
lastQuery := middleware.LastQuery()
assert.Equal(t, `INSERT INTO "orders_00" ("user_id", "product", "id") VALUES`, lastQuery[0:59])
assert.Equal(t, `INSERT INTO "orders_0" ("user_id", "product", "id") VALUES`, lastQuery[0:58])
}

func TestSelect1(t *testing.T) {
tx := db.Model(&Order{}).Where("user_id", 101).Where("id", node.Generate().Int64()).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "user_id" = $1 AND "id" = $2`, tx)
assertQueryResult(t, `SELECT * FROM "orders_1" WHERE "user_id" = $1 AND "id" = $2`, tx)
}

func TestSelect2(t *testing.T) {
tx := db.Model(&Order{}).Where("id", node.Generate().Int64()).Where("user_id", 101).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "id" = $1 AND "user_id" = $2`, tx)
assertQueryResult(t, `SELECT * FROM "orders_1" WHERE "id" = $1 AND "user_id" = $2`, tx)
}

func TestSelect3(t *testing.T) {
tx := db.Model(&Order{}).Where("id", node.Generate().Int64()).Where("user_id = 101").Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "id" = $1 AND "user_id" = 101`, tx)
assertQueryResult(t, `SELECT * FROM "orders_1" WHERE "id" = $1 AND "user_id" = 101`, tx)
}

func TestSelect4(t *testing.T) {
tx := db.Model(&Order{}).Where("product", "iPad").Where("user_id", 100).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_00" WHERE "product" = $1 AND "user_id" = $2`, tx)
assertQueryResult(t, `SELECT * FROM "orders_0" WHERE "product" = $1 AND "user_id" = $2`, tx)
}

func TestSelect5(t *testing.T) {
tx := db.Model(&Order{}).Where("user_id = 101").Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "user_id" = 101`, tx)
assertQueryResult(t, `SELECT * FROM "orders_1" WHERE "user_id" = 101`, tx)
}

func TestSelect6(t *testing.T) {
tx := db.Model(&Order{}).Where("id", node.Generate().Int64()).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "id" = $1`, tx)
assertQueryResult(t, `SELECT * FROM "orders_1" WHERE "id" = $1`, tx)
}

func TestSelect7(t *testing.T) {
tx := db.Model(&Order{}).Where("user_id", 101).Where("id > ?", node.Generate().Int64()).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "user_id" = $1 AND "id" > $2`, tx)
assertQueryResult(t, `SELECT * FROM "orders_1" WHERE "user_id" = $1 AND "id" > $2`, tx)
}

func TestSelect8(t *testing.T) {
tx := db.Model(&Order{}).Where("id > ?", node.Generate().Int64()).Where("user_id", 101).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "id" > $1 AND "user_id" = $2`, tx)
assertQueryResult(t, `SELECT * FROM "orders_1" WHERE "id" > $1 AND "user_id" = $2`, tx)
}

func TestSelect9(t *testing.T) {
tx := db.Model(&Order{}).Where("user_id = 101").First(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "user_id" = 101 ORDER BY "orders_01"."id" LIMIT 1`, tx)
assertQueryResult(t, `SELECT * FROM "orders_1" WHERE "user_id" = 101 ORDER BY "orders_1"."id" LIMIT 1`, tx)
}

func TestSelect10(t *testing.T) {
Expand All @@ -183,12 +161,12 @@ func TestSelect13(t *testing.T) {

func TestUpdate(t *testing.T) {
tx := db.Model(&Order{}).Where("user_id = ?", 100).Update("product", "new title")
assertQueryResult(t, `UPDATE "orders_00" SET "product" = $1 WHERE "user_id" = $2`, tx)
assertQueryResult(t, `UPDATE "orders_0" SET "product" = $1 WHERE "user_id" = $2`, tx)
}

func TestDelete(t *testing.T) {
tx := db.Where("user_id = ?", 100).Delete(&Order{})
assertQueryResult(t, `DELETE FROM "orders_00" WHERE "user_id" = $1`, tx)
assertQueryResult(t, `DELETE FROM "orders_0" WHERE "user_id" = $1`, tx)
}

func TestInsertMissingShardingKey(t *testing.T) {
Expand Down Expand Up @@ -241,7 +219,7 @@ func TestPKSnowflake(t *testing.T) {
db.Use(middleware)

db.Create(&Order{UserID: 100, Product: "iPhone"})
expected := `INSERT INTO "orders_00" ("user_id", "product", "id") VALUES ($1, $2, 14858`
expected := `INSERT INTO "orders_0" ("user_id", "product", "id") VALUES ($1, $2, 148`
assert.Equal(t, expected, middleware.LastQuery()[0:len(expected)])
}

Expand All @@ -255,7 +233,7 @@ func TestPKPGSequence(t *testing.T) {

db.Exec("SELECT setval('gorm_sharding_serial_for_orders', 42)")
db.Create(&Order{UserID: 100, Product: "iPhone"})
expected := `INSERT INTO "orders_00" ("user_id", "product", "id") VALUES ($1, $2, 43) RETURNING "id"`
expected := `INSERT INTO "orders_0" ("user_id", "product", "id") VALUES ($1, $2, 43) RETURNING "id"`
assert.Equal(t, expected, middleware.LastQuery())
}

Expand Down

0 comments on commit ace1265

Please sign in to comment.