Skip to content

Commit

Permalink
server: support MULTI_STATEMENTS and MULTI_RESULTS capability (pingca…
Browse files Browse the repository at this point in the history
…p#1628)

only handle COM_QUERY currently, COM_STMT_PREPARE would be supported later
http://dev.mysql.com/doc/internals/en/capability-flags.html#flag-CLIENT_MULTI_STATEMENTS

implement set option in client/server protocol
http://dev.mysql.com/doc/internals/en/com-set-option.html
  • Loading branch information
tiancaiamao authored Aug 26, 2016
1 parent e33587e commit 5ebdc1c
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 15 deletions.
43 changes: 36 additions & 7 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ import (

var defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag |
mysql.ClientConnectWithDB | mysql.ClientProtocol41 |
mysql.ClientTransactions | mysql.ClientSecureConnection | mysql.ClientFoundRows
mysql.ClientTransactions | mysql.ClientSecureConnection | mysql.ClientFoundRows |
mysql.ClientMultiStatements | mysql.ClientMultiResults

type clientConn struct {
pkg *packetIO
Expand Down Expand Up @@ -291,6 +292,8 @@ func (cc *clientConn) dispatch(data []byte) error {
return cc.handleStmtSendLongData(data)
case mysql.ComStmtReset:
return cc.handleStmtReset(data)
case mysql.ComSetOption:
return cc.handleSetOption(data)
default:
return mysql.NewErrf(mysql.ErrUnknown, "command %d not supported now", cmd)
}
Expand Down Expand Up @@ -357,12 +360,21 @@ func (cc *clientConn) writeError(e error) error {
return errors.Trace(cc.flush())
}

func (cc *clientConn) writeEOF() error {
// writeEOF writes an EOF packet.
// Note this function won't flush the stream because maybe there are more
// packets following it, the "more" argument would indicates that case.
// If "more" is true, a mysql.ServerMoreResultsExists bit would be set
// in the packet.
func (cc *clientConn) writeEOF(more bool) error {
data := cc.alloc.AllocWithLen(4, 9)

data = append(data, mysql.EOFHeader)
if cc.capability&mysql.ClientProtocol41 > 0 {
data = append(data, dumpUint16(cc.ctx.WarningCount())...)
status := cc.ctx.Status()
if more {
status |= mysql.ServerMoreResultsExists
}
data = append(data, dumpUint16(cc.ctx.Status())...)
}

Expand All @@ -377,7 +389,11 @@ func (cc *clientConn) handleQuery(sql string) (err error) {
return errors.Trace(err)
}
if rs != nil {
err = cc.writeResultset(rs, false)
if len(rs) == 1 {
err = cc.writeResultset(rs[0], false, false)
} else {
err = cc.writeMultiResultset(rs, false)
}
} else {
err = cc.writeOK()
}
Expand Down Expand Up @@ -408,13 +424,17 @@ func (cc *clientConn) handleFieldList(sql string) (err error) {
return errors.Trace(err)
}
}
if err := cc.writeEOF(); err != nil {
if err := cc.writeEOF(false); err != nil {
return errors.Trace(err)
}
return errors.Trace(cc.flush())
}

func (cc *clientConn) writeResultset(rs ResultSet, binary bool) error {
// writeResultset writes a resultset.
// If binary is true, the data would be encoded in BINARY format.
// If more is true, a flag bit would be set to indicate there are more
// resultsets, it's used to support the MULTI_RESULTS capability in mysql protocol.
func (cc *clientConn) writeResultset(rs ResultSet, binary bool, more bool) error {
defer rs.Close()
// We need to call Next before we get columns.
// Otherwise, we will get incorrect columns info.
Expand Down Expand Up @@ -442,7 +462,7 @@ func (cc *clientConn) writeResultset(rs ResultSet, binary bool) error {
}
}

if err = cc.writeEOF(); err != nil {
if err = cc.writeEOF(false); err != nil {
return errors.Trace(err)
}

Expand Down Expand Up @@ -482,10 +502,19 @@ func (cc *clientConn) writeResultset(rs ResultSet, binary bool) error {
row, err = rs.Next()
}

err = cc.writeEOF()
err = cc.writeEOF(more)
if err != nil {
return errors.Trace(err)
}

return errors.Trace(cc.flush())
}

func (cc *clientConn) writeMultiResultset(rss []ResultSet, binary bool) error {
for _, rs := range rss {
if err := cc.writeResultset(rs, binary, true); err != nil {
return errors.Trace(err)
}
}
return cc.writeOK()
}
29 changes: 26 additions & 3 deletions server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (cc *clientConn) handleStmtPrepare(sql string) error {
}
}

if err := cc.writeEOF(); err != nil {
if err := cc.writeEOF(false); err != nil {
return errors.Trace(err)
}
}
Expand All @@ -93,7 +93,7 @@ func (cc *clientConn) handleStmtPrepare(sql string) error {
}
}

if err := cc.writeEOF(); err != nil {
if err := cc.writeEOF(false); err != nil {
return errors.Trace(err)
}

Expand Down Expand Up @@ -166,7 +166,7 @@ func (cc *clientConn) handleStmtExecute(data []byte) (err error) {
return errors.Trace(cc.writeOK())
}

return errors.Trace(cc.writeResultset(rs, true))
return errors.Trace(cc.writeResultset(rs, true, false))
}

func parseStmtArgs(args []interface{}, boundParams [][]byte, nullBitmap, paramTypes, paramValues []byte) (err error) {
Expand Down Expand Up @@ -345,3 +345,26 @@ func (cc *clientConn) handleStmtReset(data []byte) (err error) {
stmt.Reset()
return cc.writeOK()
}

// See https://dev.mysql.com/doc/internals/en/com-set-option.html
func (cc *clientConn) handleSetOption(data []byte) (err error) {
if len(data) < 2 {
return mysql.ErrMalformPacket
}

switch binary.LittleEndian.Uint16(data[:2]) {
case 0:
cc.capability |= mysql.ClientMultiStatements
cc.ctx.SetClientCapability(cc.capability)
case 1:
cc.capability &^= mysql.ClientMultiStatements
cc.ctx.SetClientCapability(cc.capability)
default:
return mysql.ErrMalformPacket
}
if err = cc.writeEOF(false); err != nil {
return errors.Trace(err)
}

return errors.Trace(cc.flush())
}
5 changes: 4 additions & 1 deletion server/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ type IContext interface {
CurrentDB() string

// Execute executes a SQL statement.
Execute(sql string) (ResultSet, error)
Execute(sql string) ([]ResultSet, error)

// SetClientCapability sets client capability flags
SetClientCapability(uint32)

// Prepare prepares a statement.
Prepare(sql string) (statement IStatement, columns, params []*ColumnInfo, err error)
Expand Down
16 changes: 12 additions & 4 deletions server/driver_tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,20 +153,28 @@ func (tc *TiDBContext) WarningCount() uint16 {
}

// Execute implements IContext Execute method.
func (tc *TiDBContext) Execute(sql string) (rs ResultSet, err error) {
func (tc *TiDBContext) Execute(sql string) (rs []ResultSet, err error) {
rsList, err := tc.session.Execute(sql)
if err != nil {
return
}
if len(rsList) == 0 { // result ok
return
}
rs = &tidbResultSet{
recordSet: rsList[0],
rs = make([]ResultSet, len(rsList))
for i := 0; i < len(rsList); i++ {
rs[i] = &tidbResultSet{
recordSet: rsList[i],
}
}
return
}

// SetClientCapability implements IContext SetClientCapability method.
func (tc *TiDBContext) SetClientCapability(flags uint32) {
tc.session.SetClientCapability(flags)
}

// Close implements IContext Close method.
func (tc *TiDBContext) Close() (err error) {
return tc.session.Close()
Expand All @@ -183,7 +191,7 @@ func (tc *TiDBContext) FieldList(table string) (colums []*ColumnInfo, err error)
if err != nil {
return nil, errors.Trace(err)
}
colums, err = rs.Columns()
colums, err = rs[0].Columns()
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
33 changes: 33 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,36 @@ func runTestMultiPacket(c *C) {
}
})
}

func runTestMultiStatements(c *C) {
runTests(c, dsn, func(dbt *DBTest) {
// Create Table
dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ")

// Create Data
res := dbt.mustExec("INSERT INTO test VALUES (1, 1)")
count, err := res.RowsAffected()
c.Assert(err, IsNil, Commentf("res.RowsAffected() returned error"))
c.Assert(count, Equals, int64(1))

// Update
res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;")
count, err = res.RowsAffected()
c.Assert(err, IsNil, Commentf("res.RowsAffected() returned error"))
c.Assert(count, Equals, int64(1))

// Read
var out int
rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;")
if rows.Next() {
rows.Scan(&out)
c.Assert(out, Equals, 5)

if rows.Next() {
dbt.Error("unexpected data")
}
} else {
dbt.Error("no data")
}
})
}
4 changes: 4 additions & 0 deletions server/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ func (ts *TidbTestSuite) TestMultiPacket(c *C) {
runTestMultiPacket(c)
}

func (ts *TidbTestSuite) TestMultiStatements(c *C) {
runTestMultiStatements(c)
}

func (ts *TidbTestSuite) TestSocket(c *C) {
cfg := &Config{
LogLevel: "debug",
Expand Down
5 changes: 5 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,11 @@ func (s *session) Execute(sql string) ([]ast.RecordSet, error) {
rs = append(rs, r)
}
}

if variable.GetSessionVars(s).ClientCapability&mysql.ClientMultiResults == 0 && len(rs) > 1 {
// return the first recordset if client doesn't support ClientMultiResults.
rs = rs[:1]
}
return rs, nil
}

Expand Down

0 comments on commit 5ebdc1c

Please sign in to comment.