Skip to content

Commit

Permalink
server: add shutdown, refresh, resetconnection RPCs (pingcap#18893)
Browse files Browse the repository at this point in the history
* server: add shutdown and refresh RPCs

There are several legacy RPC commands which TiDB does not support. This reorganizes the main dispatcher to follow the same order as the protocol to make it clearer which ones are missing.

Support for ComRefresh (often used with mysqldump -F to rotate binary logs) and ComShutdown (legacy, but may still be used by an old client) are added.

I attempted to add ComStatistics, but the RPC is requires a different response packet. So I am going to skip for now.

* Add com_reset_connection rpc

* Add more tests

* Fix reset connection (lost privileges, current db)

Co-authored-by: ti-srebot <[email protected]>
  • Loading branch information
Null not nil and ti-srebot authored Aug 11, 2020
1 parent 436f5f1 commit 3c9f790
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 14 deletions.
85 changes: 71 additions & 14 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,11 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
return nil
case mysql.ComQuit:
return io.EOF
case mysql.ComInitDB:
if err := cc.useDB(ctx, dataStr); err != nil {
return err
}
return cc.writeOK()
case mysql.ComQuery: // Most frequently used command.
// For issue 1989
// Input payload may end with byte '\0', we didn't find related mysql document about it, but mysql
Expand All @@ -918,31 +923,41 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
dataStr = string(hack.String(data))
}
return cc.handleQuery(ctx, dataStr)
case mysql.ComPing:
return cc.writeOK()
case mysql.ComInitDB:
if err := cc.useDB(ctx, dataStr); err != nil {
case mysql.ComFieldList:
return cc.handleFieldList(dataStr)
// ComCreateDB, ComDropDB
case mysql.ComRefresh:
return cc.handleRefresh(ctx, data[0])
case mysql.ComShutdown: // redirect to SQL
if err := cc.handleQuery(ctx, "SHUTDOWN"); err != nil {
return err
}
return cc.writeOK()
case mysql.ComFieldList:
return cc.handleFieldList(dataStr)
// ComStatistics, ComProcessInfo, ComConnect, ComProcessKill, ComDebug
case mysql.ComPing:
return cc.writeOK()
// ComTime, ComDelayedInsert
case mysql.ComChangeUser:
return cc.handleChangeUser(ctx, data)
// ComBinlogDump, ComTableDump, ComConnectOut, ComRegisterSlave
case mysql.ComStmtPrepare:
return cc.handleStmtPrepare(dataStr)
case mysql.ComStmtExecute:
return cc.handleStmtExecute(ctx, data)
case mysql.ComStmtFetch:
return cc.handleStmtFetch(ctx, data)
case mysql.ComStmtClose:
return cc.handleStmtClose(data)
case mysql.ComStmtSendLongData:
return cc.handleStmtSendLongData(data)
case mysql.ComStmtClose:
return cc.handleStmtClose(data)
case mysql.ComStmtReset:
return cc.handleStmtReset(data)
case mysql.ComSetOption:
return cc.handleSetOption(data)
case mysql.ComChangeUser:
return cc.handleChangeUser(ctx, data)
case mysql.ComStmtFetch:
return cc.handleStmtFetch(ctx, data)
// ComDaemon, ComBinlogDumpGtid
case mysql.ComResetConnection:
return cc.handleResetConnection(ctx)
// ComEnd
default:
return mysql.NewErrf(mysql.ErrUnknown, "command %d not supported now", cmd)
}
Expand Down Expand Up @@ -1754,6 +1769,7 @@ func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error {
data = data[passLen:]
dbName, _ := parseNullTermString(data)
cc.dbname = string(hack.String(dbName))

err := cc.ctx.Close()
if err != nil {
logutil.Logger(ctx).Debug("close old context failed", zap.Error(err))
Expand All @@ -1762,16 +1778,48 @@ func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error {
if err != nil {
return err
}
return cc.handleCommonConnectionReset()
}

func (cc *clientConn) handleResetConnection(ctx context.Context) error {
user := cc.ctx.GetSessionVars().User
err := cc.ctx.Close()
if err != nil {
logutil.Logger(ctx).Debug("close old context failed", zap.Error(err))
}
var tlsStatePtr *tls.ConnectionState
if cc.tlsConn != nil {
tlsState := cc.tlsConn.ConnectionState()
tlsStatePtr = &tlsState
}
cc.ctx, err = cc.server.driver.OpenCtx(uint64(cc.connectionID), cc.capability, cc.collation, cc.dbname, tlsStatePtr)
if err != nil {
return err
}
if !cc.ctx.AuthWithoutVerification(user) {
return errors.New("Could not reset connection")
}
if cc.dbname != "" { // Restore the current DB
err = cc.useDB(context.Background(), cc.dbname)
if err != nil {
return err
}
}
cc.ctx.SetSessionManager(cc.server)

return cc.handleCommonConnectionReset()
}

func (cc *clientConn) handleCommonConnectionReset() error {
if plugin.IsEnable(plugin.Audit) {
cc.ctx.GetSessionVars().ConnectionInfo = cc.connectInfo()
}

err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
if authPlugin.OnConnectionEvent != nil {
connInfo := cc.ctx.GetSessionVars().ConnectionInfo
err = authPlugin.OnConnectionEvent(context.Background(), plugin.ChangeUser, connInfo)
err := authPlugin.OnConnectionEvent(context.Background(), plugin.ChangeUser, connInfo)
if err != nil {
return err
}
Expand All @@ -1781,7 +1829,16 @@ func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error {
if err != nil {
return err
}
return cc.writeOK()
}

// safe to noop except 0x01 "FLUSH PRIVILEGES"
func (cc *clientConn) handleRefresh(ctx context.Context, subCommand byte) error {
if subCommand == 0x01 {
if err := cc.handleQuery(ctx, "FLUSH PRIVILEGES"); err != nil {
return err
}
}
return cc.writeOK()
}

Expand Down
36 changes: 36 additions & 0 deletions server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,24 @@ func (ts *ConnTestSuite) TestDispatch(c *C) {
err: nil,
out: []byte{0x3, 0x0, 0x0, 0xe, 0x0, 0x0, 0x0},
},
{
com: mysql.ComRefresh, // flush privileges
in: []byte{0x01},
err: nil,
out: []byte{0x3, 0x0, 0x0, 0xf, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0},
},
{
com: mysql.ComRefresh, // flush logs etc
in: []byte{0x02},
err: nil,
out: []byte{0x3, 0x0, 0x0, 0x11, 0x0, 0x0, 0x0},
},
{
com: mysql.ComResetConnection,
in: nil,
err: nil,
out: []byte{0x3, 0x0, 0x0, 0x12, 0x0, 0x0, 0x0},
},
}

ts.testDispatch(c, inputs, 0)
Expand Down Expand Up @@ -401,6 +419,24 @@ func (ts *ConnTestSuite) TestDispatchClientProtocol41(c *C) {
err: nil,
out: []byte{0x7, 0x0, 0x0, 0xe, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0},
},
{
com: mysql.ComRefresh, // flush privileges
in: []byte{0x01},
err: nil,
out: []byte{0x7, 0x0, 0x0, 0xf, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0},
},
{
com: mysql.ComRefresh, // flush logs etc
in: []byte{0x02},
err: nil,
out: []byte{0x7, 0x0, 0x0, 0x11, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0},
},
{
com: mysql.ComResetConnection,
in: nil,
err: nil,
out: []byte{0x7, 0x0, 0x0, 0x12, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0},
},
}

ts.testDispatch(c, inputs, mysql.ClientProtocol41)
Expand Down
33 changes: 33 additions & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ type Session interface {
SetSessionManager(util.SessionManager)
Close()
Auth(user *auth.UserIdentity, auth []byte, salt []byte) bool
AuthWithoutVerification(user *auth.UserIdentity) bool
ShowProcess() *util.ProcessInfo
// PrepareTxnCtx is exported for test.
PrepareTxnCtx(context.Context)
Expand Down Expand Up @@ -1609,6 +1610,38 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by
return false
}

// AuthWithoutVerification is required by the ResetConnection RPC
func (s *session) AuthWithoutVerification(user *auth.UserIdentity) bool {
pm := privilege.GetPrivilegeManager(s)

// Check IP or localhost.
var success bool
user.AuthUsername, user.AuthHostname, success = pm.GetAuthWithoutVerification(user.Username, user.Hostname)
if success {
s.sessionVars.User = user
s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname)
return true
} else if user.Hostname == variable.DefHostname {
return false
}

// Check Hostname.
for _, addr := range getHostByIP(user.Hostname) {
u, h, success := pm.GetAuthWithoutVerification(user.Username, addr)
if success {
s.sessionVars.User = &auth.UserIdentity{
Username: user.Username,
Hostname: addr,
AuthUsername: u,
AuthHostname: h,
}
s.sessionVars.ActiveRoles = pm.GetDefaultRoles(u, h)
return true
}
}
return false
}

func getHostByIP(ip string) []string {
if ip == "127.0.0.1" {
return []string{variable.DefHostname}
Expand Down

0 comments on commit 3c9f790

Please sign in to comment.