Skip to content

Commit

Permalink
lightning: support base64 encoding of password (pingcap#31195)
Browse files Browse the repository at this point in the history
  • Loading branch information
sleepymole authored Mar 10, 2022
1 parent 27eed7a commit ccdd432
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 4 deletions.
2 changes: 1 addition & 1 deletion br/pkg/lightning/checkpoints/checkpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ func OpenCheckpointsDB(ctx context.Context, cfg *config.Config) (DB, error) {

switch cfg.Checkpoint.Driver {
case config.CheckpointDriverMySQL:
db, err := sql.Open("mysql", cfg.Checkpoint.DSN)
db, err := common.ConnectMySQL(cfg.Checkpoint.DSN)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
52 changes: 49 additions & 3 deletions br/pkg/lightning/common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package common
import (
"context"
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"io"
Expand All @@ -27,9 +28,12 @@ import (
"syscall"
"time"

"github.com/go-sql-driver/mysql"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/br/pkg/lightning/log"
"github.com/pingcap/tidb/br/pkg/utils"
tmysql "github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/parser/model"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -64,13 +68,55 @@ func (param *MySQLConnectParam) ToDSN() string {
return dsn
}

func (param *MySQLConnectParam) Connect() (*sql.DB, error) {
db, err := sql.Open("mysql", param.ToDSN())
func tryConnectMySQL(dsn string) (*sql.DB, error) {
driverName := "mysql"
failpoint.Inject("MockMySQLDriver", func(val failpoint.Value) {
driverName = val.(string)
})
db, err := sql.Open(driverName, dsn)
if err != nil {
return nil, errors.Trace(err)
}
if err = db.Ping(); err != nil {
_ = db.Close()
return nil, errors.Trace(err)
}
return db, nil
}

// ConnectMySQL connects MySQL with the dsn. If access is denied and the password is a valid base64 encoding,
// we will try to connect MySQL with the base64 decoding of the password.
func ConnectMySQL(dsn string) (*sql.DB, error) {
cfg, err := mysql.ParseDSN(dsn)
if err != nil {
return nil, errors.Trace(err)
}
// Try plain password first.
db, firstErr := tryConnectMySQL(dsn)
if firstErr == nil {
return db, nil
}
// If access is denied and password is encoded by base64, try the decoded string as well.
if mysqlErr, ok := errors.Cause(firstErr).(*mysql.MySQLError); ok && mysqlErr.Number == tmysql.ErrAccessDenied {
// If password is encoded by base64, try the decoded string as well.
if password, decodeErr := base64.StdEncoding.DecodeString(cfg.Passwd); decodeErr == nil && string(password) != cfg.Passwd {
cfg.Passwd = string(password)
db, err = tryConnectMySQL(cfg.FormatDSN())
if err == nil {
return db, nil
}
}
}
// If we can't connect successfully, return the first error.
return nil, errors.Trace(firstErr)
}

return db, errors.Trace(db.Ping())
func (param *MySQLConnectParam) Connect() (*sql.DB, error) {
db, err := ConnectMySQL(param.ToDSN())
if err != nil {
return nil, errors.Trace(err)
}
return db, nil
}

// IsDirExists checks if dir exists.
Expand Down
69 changes: 69 additions & 0 deletions br/pkg/lightning/common/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,26 @@ package common_test

import (
"context"
"database/sql"
"database/sql/driver"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"

"github.com/DATA-DOG/go-sqlmock"
"github.com/go-sql-driver/mysql"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/br/pkg/lightning/common"
"github.com/pingcap/tidb/br/pkg/lightning/log"
tmysql "github.com/pingcap/tidb/errno"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -92,6 +101,66 @@ func TestToDSN(t *testing.T) {
require.Equal(t, "root:123456@tcp(127.0.0.1:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'", param.ToDSN())
}

type mockDriver struct {
driver.Driver
plainPsw string
}

func (m *mockDriver) Open(dsn string) (driver.Conn, error) {
cfg, err := mysql.ParseDSN(dsn)
if err != nil {
return nil, err
}
accessDenied := cfg.Passwd != m.plainPsw
return &mockConn{accessDenied: accessDenied}, nil
}

type mockConn struct {
driver.Conn
driver.Pinger
accessDenied bool
}

func (c *mockConn) Ping(ctx context.Context) error {
if c.accessDenied {
return &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"}
}
return nil
}

func (c *mockConn) Close() error {
return nil
}

func TestConnect(t *testing.T) {
plainPsw := "dQAUoDiyb1ucWZk7"
driverName := "mysql-mock-" + strconv.Itoa(rand.Int())
sql.Register(driverName, &mockDriver{plainPsw: plainPsw})

require.NoError(t, failpoint.Enable(
"github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver",
fmt.Sprintf("return(\"%s\")", driverName)))
defer func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver"))
}()

param := common.MySQLConnectParam{
Host: "127.0.0.1",
Port: 4000,
User: "root",
Password: plainPsw,
SQLMode: "strict",
MaxAllowedPacket: 1234,
}
db, err := param.Connect()
require.NoError(t, err)
require.NoError(t, db.Close())
param.Password = base64.StdEncoding.EncodeToString([]byte(plainPsw))
db, err = param.Connect()
require.NoError(t, err)
require.NoError(t, db.Close())
}

func TestIsContextCanceledError(t *testing.T) {
require.True(t, common.IsContextCanceledError(context.Canceled))
require.False(t, common.IsContextCanceledError(io.EOF))
Expand Down

0 comments on commit ccdd432

Please sign in to comment.