Skip to content

Commit

Permalink
Perform schema checks for multiple SQL database and add context to Ad…
Browse files Browse the repository at this point in the history
…minDB DDL interface (cadence-workflow#4561)
  • Loading branch information
longquanzheng authored Oct 26, 2021
1 parent 9a072ca commit 085a799
Show file tree
Hide file tree
Showing 17 changed files with 419 additions and 121 deletions.
3 changes: 2 additions & 1 deletion common/persistence/sql/sqlPersistenceTest.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package sql

import (
"context"
"fmt"
"io/ioutil"
"os"
Expand Down Expand Up @@ -199,7 +200,7 @@ func (s *testCluster) loadDatabaseSchema(dir string, fileNames []string, overrid
if err != nil {
return fmt.Errorf("error reading contents of file %v:%v", file, err.Error())
}
err = db.Exec(string(content))
err = db.ExecSchemaOperationQuery(context.Background(), string(content))
if err != nil {
return fmt.Errorf("error loading schema from %v: %v", file, err.Error())
}
Expand Down
23 changes: 14 additions & 9 deletions common/persistence/sql/sqldriver/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,25 @@ type (
//The layer is added so that we can have a adapter to support multiple SQL databases behind a single Cadence cluster
Driver interface {

// shared methods of both non-transactional sqlx.DB and transactional sqlx.Tx
// shared methods are for both non-transactional (using sqlx.DB) and transactional (using sqlx.Tx) operation --
// if a transaction is started(using BeginTxx), then query are executed in the transaction mode. Otherwise executed in normal mode.
commonOfDbAndTx

// From sqlx.DB: those methods are executed without starting a transaction
// TODO: maybe rename to make it more clear
Exec(dbShardID int, query string, args ...interface{}) (sql.Result, error)
Select(dbShardID int, dest interface{}, query string, args ...interface{}) error
Get(dbShardID int, dest interface{}, query string, args ...interface{}) error
// BeginTxx starts a new transaction in the shard of dbShardID
BeginTxx(ctx context.Context, dbShardID int, opts *sql.TxOptions) (*sqlx.Tx, error)
Close() error

// sqlx.Tx
// Commit commits the current transaction(started by BeginTxx)
Commit() error
// Rollback rollbacks the current transaction(started by BeginTxx)
Rollback() error
// Close closes this driver(and underlying connections)
Close() error

// ExecDDL executes a DDL query
ExecDDL(ctx context.Context, dbShardID int, query string, args ...interface{}) (sql.Result, error)
// SelectForSchemaQuery executes a select query for schema(returning multiple rows).
SelectForSchemaQuery(dbShardID int, dest interface{}, query string, args ...interface{}) error
// GetForSchemaQuery executes a get query for schema(returning single row).
GetForSchemaQuery(dbShardID int, dest interface{}, query string, args ...interface{}) error
}

// the methods can be executed from either a started or transaction(then need to call Commit/Rollback), or without a transaction
Expand Down
54 changes: 12 additions & 42 deletions common/persistence/sql/sqldriver/sharded.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,40 +115,22 @@ func (s *sharded) SelectContext(ctx context.Context, dbShardID int, dest interfa

// below are non-transactional methods only

func (s *sharded) Exec(dbShardID int, query string, args ...interface{}) (sql.Result, error) {
if dbShardID == sqlplugin.DbShardUndefined {
return nil, fmt.Errorf("DbShardUndefined shouldn't be used to Exec, there must be a bug")
}
if dbShardID == sqlplugin.DbAllShards {
// NOTE: this can only be safely used for schema operation
var errs []error
for _, db := range s.dbs {
_, err := db.Exec(query, args...)
if err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
// Note that this will break sqlplugin.ErrorChecker contract, but it's okay for now as DbAllShards are only being used for schema
return nil, multierr.Combine(errs...)
}
return newShardedSqlExecResult(), nil
}
return s.dbs[dbShardID].Exec(query, args...)
func (s *sharded) ExecDDL(ctx context.Context, dbShardID int, query string, args ...interface{}) (sql.Result, error) {
// sharded SQL driver doesn't implement any schema operation as it's hard to guarantee the correctness.
// schema operation across shards is implemented by application layer
return nil, fmt.Errorf("sharded SQL driver shouldn't be used to ExecDDL, there must be a bug")
}

func (s *sharded) Select(dbShardID int, dest interface{}, query string, args ...interface{}) error {
if dbShardID == sqlplugin.DbShardUndefined || dbShardID == sqlplugin.DbAllShards {
return fmt.Errorf("invalid dbShardID %v shouldn't be used to Select, there must be a bug", dbShardID)
}
return s.dbs[dbShardID].Select(dest, query, args...)
func (s *sharded) SelectForSchemaQuery(dbShardID int, dest interface{}, query string, args ...interface{}) error {
// sharded SQL driver doesn't implement any schema operation as it's hard to guarantee the correctness.
// schema operation across shards is implemented by application layer
return fmt.Errorf("sharded SQL driver shouldn't be used to SelectForSchemaQuery, there must be a bug")
}

func (s *sharded) Get(dbShardID int, dest interface{}, query string, args ...interface{}) error {
if dbShardID == sqlplugin.DbShardUndefined || dbShardID == sqlplugin.DbAllShards {
return fmt.Errorf("invalid dbShardID %v shouldn't be used to Get, there must be a bug", dbShardID)
}
return s.dbs[dbShardID].Get(dest, query, args...)
func (s *sharded) GetForSchemaQuery(dbShardID int, dest interface{}, query string, args ...interface{}) error {
// sharded SQL driver doesn't implement any schema operation as it's hard to guarantee the correctness.
// schema operation across shards is implemented by application layer
return fmt.Errorf("sharded SQL driver shouldn't be used to GetForSchemaQuery, there must be a bug")
}

func (s *sharded) BeginTxx(ctx context.Context, dbShardID int, opts *sql.TxOptions) (*sqlx.Tx, error) {
Expand Down Expand Up @@ -182,18 +164,6 @@ func (s *sharded) Rollback() error {
return s.tx.Rollback()
}

func newShardedSqlExecResult() sql.Result {
return &shardedSqlExecResult{}
}

func (s shardedSqlExecResult) LastInsertId() (int64, error) {
return 0, fmt.Errorf("not implemented for sharded SQL driver")
}

func (s shardedSqlExecResult) RowsAffected() (int64, error) {
return 0, fmt.Errorf("not implemented for sharded SQL driver")
}

func getUnmatchedTxnError(requestShardID, startedShardId int) error {
return fmt.Errorf("requested dbShardID %v doesn't match with started transaction shardID %v, must be a bug", requestShardID, startedShardId)
}
8 changes: 4 additions & 4 deletions common/persistence/sql/sqldriver/singleton.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ func (s *singleton) SelectContext(ctx context.Context, _ int, dest interface{},

// below are non-transactional methods only

func (s *singleton) Exec(_ int, query string, args ...interface{}) (sql.Result, error) {
return s.db.Exec(query, args...)
func (s *singleton) ExecDDL(ctx context.Context, _ int, query string, args ...interface{}) (sql.Result, error) {
return s.db.ExecContext(ctx, query, args...)
}

func (s *singleton) Select(_ int, dest interface{}, query string, args ...interface{}) error {
func (s *singleton) SelectForSchemaQuery(_ int, dest interface{}, query string, args ...interface{}) error {
return s.db.Select(dest, query, args...)
}

func (s *singleton) Get(_ int, dest interface{}, query string, args ...interface{}) error {
func (s *singleton) GetForSchemaQuery(_ int, dest interface{}, query string, args ...interface{}) error {
return s.db.Get(dest, query, args...)
}

Expand Down
4 changes: 2 additions & 2 deletions common/persistence/sql/sqlplugin/dbSharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ import (
const (
// this means the query need to execute in one shard but the shard should be fixed/static, e.g. for domain, queue storage are single shard
DbDefaultShard = 0
// this is should never being used to query anything
// this is should never being used in sharded SQL driver. It is used in admin/schema operation in singleton driver, which ignores all the shardID parameter
DbShardUndefined = -1
// this means the query needs to execute in all dbShards, e.g. used for executing queries for schemas,
// this means the query needs to execute in all dbShards in sharded SQL driver (currently not supported)
DbAllShards = -2
)

Expand Down
3 changes: 2 additions & 1 deletion common/persistence/sql/sqlplugin/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,8 @@ type (
DropAllTables(database string) error
CreateDatabase(database string) error
DropDatabase(database string) error
Exec(stmt string, args ...interface{}) error
// ExecSchemaOperationQuery allows passing in any query, but it must be schema operation (DDL)
ExecSchemaOperationQuery(ctx context.Context, stmt string, args ...interface{}) error
}

// Tx defines the API for a SQL transaction
Expand Down
26 changes: 13 additions & 13 deletions common/persistence/sql/sqlplugin/mysql/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package mysql

import (
"context"
"fmt"
"time"

Expand Down Expand Up @@ -61,48 +62,47 @@ const (

// CreateSchemaVersionTables sets up the schema version tables
func (mdb *db) CreateSchemaVersionTables() error {
if err := mdb.Exec(createSchemaVersionTableQuery); err != nil {
if err := mdb.ExecSchemaOperationQuery(context.Background(), createSchemaVersionTableQuery); err != nil {
return err
}
return mdb.Exec(createSchemaUpdateHistoryTableQuery)
return mdb.ExecSchemaOperationQuery(context.Background(), createSchemaUpdateHistoryTableQuery)
}

// ReadSchemaVersion returns the current schema version for the keyspace
func (mdb *db) ReadSchemaVersion(database string) (string, error) {
var version string
err := mdb.driver.Get(sqlplugin.DbDefaultShard, &version, readSchemaVersionQuery, database)
err := mdb.driver.GetForSchemaQuery(sqlplugin.DbShardUndefined, &version, readSchemaVersionQuery, database)
return version, err
}

// UpdateSchemaVersion updates the schema version for the keyspace
func (mdb *db) UpdateSchemaVersion(database string, newVersion string, minCompatibleVersion string) error {
return mdb.Exec(writeSchemaVersionQuery, database, time.Now(), newVersion, minCompatibleVersion)
return mdb.ExecSchemaOperationQuery(context.Background(), writeSchemaVersionQuery, database, time.Now(), newVersion, minCompatibleVersion)
}

// WriteSchemaUpdateLog adds an entry to the schema update history table
func (mdb *db) WriteSchemaUpdateLog(oldVersion string, newVersion string, manifestMD5 string, desc string) error {
now := time.Now().UTC()
return mdb.Exec(writeSchemaUpdateHistoryQuery, now.Year(), int(now.Month()), now, oldVersion, newVersion, manifestMD5, desc)
return mdb.ExecSchemaOperationQuery(context.Background(), writeSchemaUpdateHistoryQuery, now.Year(), int(now.Month()), now, oldVersion, newVersion, manifestMD5, desc)
}

// Exec executes a sql statement
// ExecSchemaOperationQuery executes a sql statement for schema ONLY. DO NOT use it in other cases, otherwise it will not work for multiple SQL database.
// For Sharded SQL, it will execute the statement for all shards
// TODO: rename to ExecSchemaQuery so that we know it should use DB_ALL_SHARDS
func (mdb *db) Exec(stmt string, args ...interface{}) error {
_, err := mdb.driver.Exec(sqlplugin.DbAllShards, stmt, args...)
func (mdb *db) ExecSchemaOperationQuery(ctx context.Context, stmt string, args ...interface{}) error {
_, err := mdb.driver.ExecDDL(ctx, sqlplugin.DbShardUndefined, stmt, args...)
return err
}

// ListTables returns a list of tables in this database
func (mdb *db) ListTables(database string) ([]string, error) {
var tables []string
err := mdb.driver.Select(sqlplugin.DbDefaultShard, &tables, fmt.Sprintf(listTablesQuery, database))
err := mdb.driver.SelectForSchemaQuery(sqlplugin.DbShardUndefined, &tables, fmt.Sprintf(listTablesQuery, database))
return tables, err
}

// DropTable drops a given table from the database
func (mdb *db) DropTable(name string) error {
return mdb.Exec(fmt.Sprintf(dropTableQuery, name))
return mdb.ExecSchemaOperationQuery(context.Background(), fmt.Sprintf(dropTableQuery, name))
}

// DropAllTables drops all tables from this database
Expand All @@ -121,10 +121,10 @@ func (mdb *db) DropAllTables(database string) error {

// CreateDatabase creates a database if it doesn't exist
func (mdb *db) CreateDatabase(name string) error {
return mdb.Exec(fmt.Sprintf(createDatabaseQuery, name))
return mdb.ExecSchemaOperationQuery(context.Background(), fmt.Sprintf(createDatabaseQuery, name))
}

// DropDatabase drops a database
func (mdb *db) DropDatabase(name string) error {
return mdb.Exec(fmt.Sprintf(dropDatabaseQuery, name))
return mdb.ExecSchemaOperationQuery(context.Background(), fmt.Sprintf(dropDatabaseQuery, name))
}
26 changes: 13 additions & 13 deletions common/persistence/sql/sqlplugin/postgres/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package postgres

import (
"context"
"fmt"
"time"

Expand Down Expand Up @@ -67,48 +68,47 @@ const (

// CreateSchemaVersionTables sets up the schema version tables
func (pdb *db) CreateSchemaVersionTables() error {
if err := pdb.Exec(createSchemaVersionTableQuery); err != nil {
if err := pdb.ExecSchemaOperationQuery(context.Background(), createSchemaVersionTableQuery); err != nil {
return err
}
return pdb.Exec(createSchemaUpdateHistoryTableQuery)
return pdb.ExecSchemaOperationQuery(context.Background(), createSchemaUpdateHistoryTableQuery)
}

// ReadSchemaVersion returns the current schema version for the keyspace
func (pdb *db) ReadSchemaVersion(database string) (string, error) {
var version string
err := pdb.driver.Get(sqlplugin.DbDefaultShard, &version, readSchemaVersionQuery, database)
err := pdb.driver.GetForSchemaQuery(sqlplugin.DbShardUndefined, &version, readSchemaVersionQuery, database)
return version, err
}

// UpdateSchemaVersion updates the schema version for the keyspace
func (pdb *db) UpdateSchemaVersion(database string, newVersion string, minCompatibleVersion string) error {
return pdb.Exec(writeSchemaVersionQuery, database, time.Now(), newVersion, minCompatibleVersion)
return pdb.ExecSchemaOperationQuery(context.Background(), writeSchemaVersionQuery, database, time.Now(), newVersion, minCompatibleVersion)
}

// WriteSchemaUpdateLog adds an entry to the schema update history table
func (pdb *db) WriteSchemaUpdateLog(oldVersion string, newVersion string, manifestMD5 string, desc string) error {
now := time.Now().UTC()
return pdb.Exec(writeSchemaUpdateHistoryQuery, now.Year(), int(now.Month()), now, oldVersion, newVersion, manifestMD5, desc)
return pdb.ExecSchemaOperationQuery(context.Background(), writeSchemaUpdateHistoryQuery, now.Year(), int(now.Month()), now, oldVersion, newVersion, manifestMD5, desc)
}

// Exec executes a sql statement
// ExecSchemaOperationQuery executes a sql statement for schema ONLY. DO NOT use it in other cases, otherwise it will not work for multiple SQL database.
// For Sharded SQL, it will execute the statement for all shards
// TODO: rename to ExecSchemaQuery so that we know it should use DB_ALL_SHARDS
func (pdb *db) Exec(stmt string, args ...interface{}) error {
_, err := pdb.driver.Exec(sqlplugin.DbAllShards, stmt, args...)
func (pdb *db) ExecSchemaOperationQuery(ctx context.Context, stmt string, args ...interface{}) error {
_, err := pdb.driver.ExecDDL(ctx, sqlplugin.DbShardUndefined, stmt, args...)
return err
}

// ListTables returns a list of tables in this database
func (pdb *db) ListTables(database string) ([]string, error) {
var tables []string
err := pdb.driver.Select(sqlplugin.DbDefaultShard, &tables, listTablesQuery)
err := pdb.driver.SelectForSchemaQuery(sqlplugin.DbShardUndefined, &tables, listTablesQuery)
return tables, err
}

// DropTable drops a given table from the database
func (pdb *db) DropTable(name string) error {
return pdb.Exec(fmt.Sprintf(dropTableQuery, name))
return pdb.ExecSchemaOperationQuery(context.Background(), fmt.Sprintf(dropTableQuery, name))
}

// DropAllTables drops all tables from this database
Expand All @@ -127,10 +127,10 @@ func (pdb *db) DropAllTables(database string) error {

// CreateDatabase creates a database if it doesn't exist
func (pdb *db) CreateDatabase(name string) error {
return pdb.Exec(fmt.Sprintf(createDatabaseQuery, name))
return pdb.ExecSchemaOperationQuery(context.Background(), fmt.Sprintf(createDatabaseQuery, name))
}

// DropDatabase drops a database
func (pdb *db) DropDatabase(name string) error {
return pdb.Exec(fmt.Sprintf(dropDatabaseQuery, name))
return pdb.ExecSchemaOperationQuery(context.Background(), fmt.Sprintf(dropDatabaseQuery, name))
}
18 changes: 9 additions & 9 deletions tools/cassandra/cqlclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ const (
`WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor' : %v};`
)

var _ schema.DB = (*CqlClient)(nil)
var _ schema.SchemaClient = (*CqlClient)(nil)

// NewCQLClient returns a new instance of CQLClient
func NewCQLClient(cfg *CQLClientConfig) (*CqlClient, error) {
Expand Down Expand Up @@ -123,12 +123,12 @@ func (client *CqlClient) DropDatabase(name string) error {

// createKeyspace creates a cassandra Keyspace if it doesn't exist
func (client *CqlClient) CreateKeyspace(name string) error {
return client.Exec(fmt.Sprintf(createKeyspaceCQL, name, client.nReplicas))
return client.ExecDDLQuery(fmt.Sprintf(createKeyspaceCQL, name, client.nReplicas))
}

// DropKeyspace drops a Keyspace
func (client *CqlClient) DropKeyspace(name string) error {
return client.Exec(fmt.Sprintf("DROP KEYSPACE %v", name))
return client.ExecDDLQuery(fmt.Sprintf("DROP KEYSPACE %v", name))
}

func (client *CqlClient) DropAllTables() error {
Expand All @@ -137,10 +137,10 @@ func (client *CqlClient) DropAllTables() error {

// CreateSchemaVersionTables sets up the schema version tables
func (client *CqlClient) CreateSchemaVersionTables() error {
if err := client.Exec(createSchemaVersionTableCQL); err != nil {
if err := client.ExecDDLQuery(createSchemaVersionTableCQL); err != nil {
return err
}
return client.Exec(createSchemaUpdateHistoryTableCQL)
return client.ExecDDLQuery(createSchemaUpdateHistoryTableCQL)
}

// ReadSchemaVersion returns the current schema version for the Keyspace
Expand Down Expand Up @@ -172,8 +172,8 @@ func (client *CqlClient) WriteSchemaUpdateLog(oldVersion string, newVersion stri
return query.Exec()
}

// Exec executes a cql statement
func (client *CqlClient) Exec(stmt string, args ...interface{}) error {
// ExecDDLQuery executes a cql statement
func (client *CqlClient) ExecDDLQuery(stmt string, args ...interface{}) error {
return client.session.Query(stmt, args...).Exec()
}

Expand Down Expand Up @@ -216,12 +216,12 @@ func (client *CqlClient) listTypes() ([]string, error) {

// dropTable drops a given table from the Keyspace
func (client *CqlClient) dropTable(name string) error {
return client.Exec(fmt.Sprintf("DROP TABLE %v", name))
return client.ExecDDLQuery(fmt.Sprintf("DROP TABLE %v", name))
}

// dropType drops a given type from the Keyspace
func (client *CqlClient) dropType(name string) error {
return client.Exec(fmt.Sprintf("DROP TYPE %v", name))
return client.ExecDDLQuery(fmt.Sprintf("DROP TYPE %v", name))
}

// dropAllTablesTypes deletes all tables/types in the
Expand Down
Loading

0 comments on commit 085a799

Please sign in to comment.