Readwrite-splitting
hyperphoton committed Mar 15, 2022
1 parent cde328a commit 760b474
Showing 4 changed files with 181 additions and 21 deletions.
21 changes: 21 additions & 0 deletions README.md
@@ -14,6 +14,7 @@ Gorm Sharding is a high-performance database sharding middleware.
- Lightning-fast. No network-based middleware, as fast as Go.
- Multiple database (PostgreSQL, MySQL) support.
- Integrated primary key generator (Snowflake, PostgreSQL Sequence, Custom, ...).
- Read/write splitting.

## Install

@@ -94,6 +95,26 @@ Recommended options:
- [Snowflake](https://github.com/bwmarrin/snowflake)
- [Database sequence managed manually](https://www.postgresql.org/docs/current/sql-createsequence.html)
## Read/write splitting
```go
dsnRead := "host=localhost user=gorm password=gorm dbname=gorm-slave port=5432 sslmode=disable"
dsnWrite := "host=localhost user=gorm password=gorm dbname=gorm port=5432 sslmode=disable"

connRead := postgres.Open(dsnRead)
connWrite := postgres.Open(dsnWrite)

db, err := gorm.Open(connWrite, &gorm.Config{})

db.Use(sharding.Register(sharding.Config{
    ShardingKey:         "user_id",
    NumberOfShards:      64,
    PrimaryKeyGenerator: sharding.PKSnowflake,
    ReadConnections:     []gorm.Dialector{connRead},
    WriteConnections:    []gorm.Dialector{connWrite},
}))
```
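
Once registered, statements on sharded tables are dispatched by statement type: SELECT is served by one of the `ReadConnections`, while INSERT, UPDATE and DELETE go to the `WriteConnections`; everything else falls back to the default connection. A minimal sketch, assuming the config above is also registered for a sharded `orders` table (for example by passing an `Order` model to `Register`, not shown here):

```go
// Hypothetical model for the sharded `orders` table.
type Order struct {
    ID      int64
    UserID  int64
    Product string
}

// INSERT on a sharded table is routed to one of the WriteConnections (connWrite here).
db.Create(&Order{UserID: 100, Product: "iPad"})

// SELECT on a sharded table is routed to one of the ReadConnections (connRead here).
var orders []Order
db.Where("user_id = ?", 100).Find(&orders)
```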
## Sharding process
This graph shows how Gorm Sharding works.
40 changes: 34 additions & 6 deletions conn_pool.go
@@ -3,6 +3,7 @@ package sharding
import (
"context"
"database/sql"
"math/rand"

"gorm.io/gorm"
)
@@ -36,7 +37,7 @@ func (pool ConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stm
}

func (pool ConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
ftQuery, stQuery, table, err := pool.sharding.resolve(query, args...)
ftQuery, stQuery, table, stmtType, err := pool.sharding.resolve(query, args...)
if err != nil {
return nil, err
}
@@ -51,12 +52,14 @@ }
}
}

return pool.ConnPool.ExecContext(ctx, stQuery, args...)
cp := pool.GetReadWriteConn(table, stmtType)

return cp.ExecContext(ctx, stQuery, args...)
}

// https://github.com/go-gorm/gorm/blob/v1.21.11/callbacks/query.go#L18
func (pool ConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
ftQuery, stQuery, table, err := pool.sharding.resolve(query, args...)
ftQuery, stQuery, table, stmtType, err := pool.sharding.resolve(query, args...)
if err != nil {
return nil, err
}
@@ -71,14 +74,18 @@ }
}
}

return pool.ConnPool.QueryContext(ctx, stQuery, args...)
cp := pool.GetReadWriteConn(table, stmtType)

return cp.QueryContext(ctx, stQuery, args...)
}

func (pool ConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
_, query, _, _ = pool.sharding.resolve(query, args...)
_, query, table, stmtType, _ := pool.sharding.resolve(query, args...)
pool.sharding.querys.Store("last_query", query)

return pool.ConnPool.QueryRowContext(ctx, query, args...)
cp := pool.GetReadWriteConn(table, stmtType)

return cp.QueryRowContext(ctx, query, args...)
}

// BeginTx implements ConnPoolBeginner.BeginTx
@@ -111,3 +118,24 @@ func (pool *ConnPool) Rollback() error {
func (pool *ConnPool) Ping() error {
return nil
}

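// GetReadWriteConn returns a read connection for SELECT statements and a write connection for
// INSERT, UPDATE and DELETE, falling back to the default ConnPool when the table has no
// dedicated connections configured.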
func (pool *ConnPool) GetReadWriteConn(table, stmtType string) gorm.ConnPool {
cp := pool.ConnPool
if table != "" {
switch stmtType {
case "SELECT":
if conns, ok := pool.sharding.readConns[table]; ok {
if len(conns) > 0 {
cp = conns[rand.Intn(len(conns))]
}
}
case "INSERT", "UPDATE", "DELETE":
if conns, ok := pool.sharding.writeConns[table]; ok {
if len(conns) > 0 {
cp = conns[rand.Intn(len(conns))]
}
}
}
}
return cp
}
48 changes: 42 additions & 6 deletions sharding.go
@@ -25,6 +25,9 @@ type Sharding struct {
querys sync.Map
snowflakeNodes []*snowflake.Node

readConns map[string][]gorm.ConnPool
writeConns map[string][]gorm.ConnPool

_config Config
_tables []interface{}
}
@@ -79,6 +82,12 @@ type Config struct {
// return nodes[tableIdx].Generate().Int64()
// }
PrimaryKeyGeneratorFn func(tableIdx int64) int64

// ReadConnections specifies the connections used for reads (SELECT).
ReadConnections []gorm.Dialector

// WriteConnections specifies the connections used for writes (INSERT, UPDATE, DELETE).
WriteConnections []gorm.Dialector
}

func Register(config Config, tables ...interface{}) *Sharding {
@@ -195,6 +204,30 @@ func (s *Sharding) LastQuery() string {
// Initialize implement for Gorm plugin interface
func (s *Sharding) Initialize(db *gorm.DB) error {
s.DB = db
err := s.compile()
if err != nil {
return err
}

s.readConns = make(map[string][]gorm.ConnPool)
s.writeConns = make(map[string][]gorm.ConnPool)
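// For each sharded table, open a dedicated *gorm.DB per configured read/write dialector
// and keep its ConnPool for statement routing.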
for t, c := range s.configs {
for _, dialector := range c.ReadConnections {
db, err := gorm.Open(dialector, s.DB.Config)
if err != nil {
return err
}
s.readConns[t] = append(s.readConns[t], db.Config.ConnPool)
}
for _, dialector := range c.WriteConnections {
db, err := gorm.Open(dialector, s.DB.Config)
if err != nil {
return err
}
s.writeConns[t] = append(s.writeConns[t], db.Config.ConnPool)
}
}

s.registerConnPool(db)

for t, c := range s.configs {
@@ -215,11 +248,11 @@ func (s *Sharding) Initialize(db *gorm.DB) error {
s.snowflakeNodes[i] = n
}

return s.compile()
return nil
}

// resolve splits the original query into a full-table query and a sharded-table query,
// and reports the statement type used for read/write routing
func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, tableName string, err error) {
func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, tableName, stmtType string, err error) {
ftQuery = query
stQuery = query
if len(s.configs) == 0 {
@@ -228,7 +261,7 @@

expr, err := sqlparser.NewParser(strings.NewReader(query)).ParseStatement()
if err != nil {
return ftQuery, stQuery, tableName, nil
return ftQuery, stQuery, tableName, stmtType, nil
}

var table *sqlparser.TableName
@@ -248,20 +281,23 @@
}
table = tbl
condition = stmt.Condition

stmtType = "SELECT"
case *sqlparser.InsertStatement:
table = stmt.TableName
isInsert = true
insertNames = stmt.ColumnNames
insertValues = stmt.Expressions[0].Exprs
stmtType = "INSERT"
case *sqlparser.UpdateStatement:
condition = stmt.Condition
table = stmt.TableName
stmtType = "UPDATE"
case *sqlparser.DeleteStatement:
condition = stmt.Condition
table = stmt.TableName
stmtType = "DELETE"
default:
return ftQuery, stQuery, "", sqlparser.ErrNotImplemented
return ftQuery, stQuery, "", "", sqlparser.ErrNotImplemented
}

tableName = table.Name.Name
@@ -313,7 +349,7 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery,
if fillID {
tblIdx, err := strconv.Atoi(strings.Replace(suffix, "_", "", 1))
if err != nil {
return ftQuery, stQuery, tableName, err
return ftQuery, stQuery, tableName, "", err
}
id := r.PrimaryKeyGeneratorFn(int64(tblIdx))
insertNames = append(insertNames, &sqlparser.Ident{Name: "id"})
93 changes: 84 additions & 9 deletions sharding_test.go
@@ -37,35 +37,82 @@ func databaseURL() string {
return databaseURL
}

func databaseReadURL() string {
databaseURL := os.Getenv("DATABASE_READ_URL")
if len(databaseURL) == 0 {
databaseURL = "postgres://localhost:5432/sharding-read-test?sslmode=disable"
if os.Getenv("DIALECTOR") == "mysql" {
databaseURL = "root@tcp(127.0.0.1:3306)/sharding-read-test?charset=utf8mb4"
}
}
return databaseURL
}

func databaseWriteURL() string {
databaseURL := os.Getenv("DATABASE_WRITE_URL")
if len(databaseURL) == 0 {
databaseURL = "postgres://localhost:5432/sharding-write-test?sslmode=disable"
if os.Getenv("DIALECTOR") == "mysql" {
databaseURL = "root@tcp(127.0.0.1:3306)/sharding-write-test?charset=utf8mb4"
}
}
return databaseURL
}

var (
dbConfig = postgres.Config{
DSN: databaseURL(),
PreferSimpleProtocol: true,
}
db *gorm.DB

shardingConfig = Config{
DoubleWrite: true,
ShardingKey: "user_id",
NumberOfShards: 4,
PrimaryKeyGenerator: PKSnowflake,
dbReadConfig = postgres.Config{
DSN: databaseReadURL(),
PreferSimpleProtocol: true,
}
dbWriteConfig = postgres.Config{
DSN: databaseWriteURL(),
PreferSimpleProtocol: true,
}
db, dbRead, dbWrite *gorm.DB

middleware = Register(shardingConfig, &Order{})
node, _ = snowflake.NewNode(1)
shardingConfig Config
middleware *Sharding
node, _ = snowflake.NewNode(1)
)

func init() {
if os.Getenv("DIALECTOR") == "mysql" {
db, _ = gorm.Open(mysql.Open(databaseURL()), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
})
dbRead, _ = gorm.Open(mysql.Open(databaseReadURL()), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
})
dbWrite, _ = gorm.Open(mysql.Open(databaseWriteURL()), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
})
} else {
db, _ = gorm.Open(postgres.New(dbConfig), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
})
dbRead, _ = gorm.Open(postgres.New(dbReadConfig), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
})
dbWrite, _ = gorm.Open(postgres.New(dbWriteConfig), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
})
}

shardingConfig = Config{
DoubleWrite: true,
ShardingKey: "user_id",
NumberOfShards: 4,
PrimaryKeyGenerator: PKSnowflake,
ReadConnections: []gorm.Dialector{dbRead.Dialector},
WriteConnections: []gorm.Dialector{dbWrite.Dialector},
}

middleware = Register(shardingConfig, &Order{})

fmt.Println("Clean only tables ...")
dropTables()
fmt.Println("AutoMigrate tables ...")
@@ -80,6 +127,16 @@ func init() {
user_id bigint,
product text
)`)
dbRead.Exec(`CREATE TABLE ` + table + ` (
id bigint PRIMARY KEY,
user_id bigint,
product text
)`)
dbWrite.Exec(`CREATE TABLE ` + table + ` (
id bigint PRIMARY KEY,
user_id bigint,
product text
)`)
}

db.Use(middleware)
@@ -89,6 +146,8 @@ func dropTables() {
tables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"}
for _, table := range tables {
db.Exec("DROP TABLE IF EXISTS " + table)
dbRead.Exec("DROP TABLE IF EXISTS " + table)
dbWrite.Exec("DROP TABLE IF EXISTS " + table)
db.Exec(("DROP SEQUENCE IF EXISTS gorm_sharding_" + table + "_id_seq"))
}
}
@@ -264,6 +323,22 @@ func TestPKPGSequence(t *testing.T) {
assert.Equal(t, expected, middleware.LastQuery())
}

func TestReadWriteConnections(t *testing.T) {
dbRead.Exec("INSERT INTO orders_0 (id, product, user_id) VALUES(1, 'iPad', 100)")
dbWrite.Exec("INSERT INTO orders_0 (id, product, user_id) VALUES(1, 'iPad', 100)")

var order Order
db.Model(&Order{}).Where("user_id", 100).Find(&order)
assert.Equal(t, "iPad", order.Product)

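// The following UPDATE is routed to the write connection only, so reads through db
// (served by dbRead) still see the old value.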
db.Model(&Order{}).Where("user_id", 100).Update("product", "iPhone")
db.Table("orders_0").Where("user_id", 100).Find(&order)
assert.Equal(t, "iPad", order.Product)

dbWrite.Table("orders_0").Where("user_id", 100).Find(&order)
assert.Equal(t, "iPhone", order.Product)
}

func assertQueryResult(t *testing.T, expected string, tx *gorm.DB) {
t.Helper()
assert.Equal(t, toDialect(expected), middleware.LastQuery())
