Skip to content

Commit

Permalink
*: refine text protocol multiple query response (pingcap#11263)
Browse files Browse the repository at this point in the history
  • Loading branch information
lysu authored and jackysp committed Jul 17, 2019
1 parent 23d4c97 commit 77b6858
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 9 deletions.
34 changes: 27 additions & 7 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ import (
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/memory"
"github.com/pingcap/tidb/util/sqlexec"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -945,18 +946,22 @@ func (cc *clientConn) flush() error {

func (cc *clientConn) writeOK() error {
msg := cc.ctx.LastMessage()
return cc.writeOkWith(msg, cc.ctx.AffectedRows(), cc.ctx.LastInsertID(), cc.ctx.Status(), cc.ctx.WarningCount())
}

func (cc *clientConn) writeOkWith(msg string, affectedRows, lastInsertID uint64, status, warnCnt uint16) error {
enclen := 0
if len(msg) > 0 {
enclen = lengthEncodedIntSize(uint64(len(msg))) + len(msg)
}

data := cc.alloc.AllocWithLen(4, 32+enclen)
data = append(data, mysql.OKHeader)
data = dumpLengthEncodedInt(data, cc.ctx.AffectedRows())
data = dumpLengthEncodedInt(data, cc.ctx.LastInsertID())
data = dumpLengthEncodedInt(data, affectedRows)
data = dumpLengthEncodedInt(data, lastInsertID)
if cc.capability&mysql.ClientProtocol41 > 0 {
data = dumpUint16(data, cc.ctx.Status())
data = dumpUint16(data, cc.ctx.WarningCount())
data = dumpUint16(data, status)
data = dumpUint16(data, warnCnt)
}
if enclen > 0 {
// although MySQL manual says the info message is string<EOF>(https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html),
Expand Down Expand Up @@ -1403,12 +1408,27 @@ func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs ResultSet
}

func (cc *clientConn) writeMultiResultset(ctx context.Context, rss []ResultSet, binary bool) error {
for _, rs := range rss {
if err := cc.writeResultset(ctx, rs, binary, mysql.ServerMoreResultsExists, 0); err != nil {
for i, rs := range rss {
lastRs := i == len(rss)-1
if r, ok := rs.(*tidbResultSet).recordSet.(sqlexec.MultiQueryNoDelayResult); ok {
status := r.Status()
if !lastRs {
status |= mysql.ServerMoreResultsExists
}
if err := cc.writeOkWith(r.LastMessage(), r.AffectedRows(), r.LastInsertID(), status, r.WarnCount()); err != nil {
return err
}
continue
}
status := uint16(0)
if !lastRs {
status |= mysql.ServerMoreResultsExists
}
if err := cc.writeResultset(ctx, rs, binary, status, 0); err != nil {
return err
}
}
return cc.writeOK()
return nil
}

func (cc *clientConn) setConn(conn net.Conn) {
Expand Down
59 changes: 57 additions & 2 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecu
s.processInfo.Store(&pi)
}

func (s *session) executeStatement(ctx context.Context, connID uint64, stmtNode ast.StmtNode, stmt sqlexec.Statement, recordSets []sqlexec.RecordSet) ([]sqlexec.RecordSet, error) {
func (s *session) executeStatement(ctx context.Context, connID uint64, stmtNode ast.StmtNode, stmt sqlexec.Statement, recordSets []sqlexec.RecordSet, inMulitQuery bool) ([]sqlexec.RecordSet, error) {
s.SetValue(sessionctx.QueryString, stmt.OriginText())
if _, ok := stmtNode.(ast.DDLNode); ok {
s.SetValue(sessionctx.LastExecuteDDL, true)
Expand All @@ -1016,6 +1016,16 @@ func (s *session) executeStatement(ctx context.Context, connID uint64, stmtNode
sessionExecuteRunDurationGeneral.Observe(time.Since(startTime).Seconds())
}

if inMulitQuery && recordSet == nil {
recordSet = &multiQueryNoDelayRecordSet{
affectedRows: s.AffectedRows(),
lastMessage: s.LastMessage(),
warnCount: s.sessionVars.StmtCtx.WarningCount(),
lastInsertID: s.sessionVars.StmtCtx.LastInsertID,
status: s.sessionVars.Status,
}
}

if recordSet != nil {
recordSets = append(recordSets, recordSet)
}
Expand Down Expand Up @@ -1062,6 +1072,7 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec

var tempStmtNodes []ast.StmtNode
compiler := executor.Compiler{Ctx: s}
multiQuery := len(stmtNodes) > 1
for idx, stmtNode := range stmtNodes {
s.PrepareTxnCtx(ctx)

Expand Down Expand Up @@ -1098,7 +1109,7 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec
s.currentPlan = stmt.Plan

// Step3: Execute the physical plan.
if recordSets, err = s.executeStatement(ctx, connID, stmtNode, stmt, recordSets); err != nil {
if recordSets, err = s.executeStatement(ctx, connID, stmtNode, stmt, recordSets, multiQuery); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -1952,3 +1963,47 @@ func (s *session) recordTransactionCounter(err error) {
}
}
}

type multiQueryNoDelayRecordSet struct {
affectedRows uint64
lastMessage string
status uint16
warnCount uint16
lastInsertID uint64
}

func (c *multiQueryNoDelayRecordSet) Fields() []*ast.ResultField {
panic("unsupported method")
}

func (c *multiQueryNoDelayRecordSet) Next(ctx context.Context, chk *chunk.Chunk) error {
panic("unsupported method")
}

func (c *multiQueryNoDelayRecordSet) NewChunk() *chunk.Chunk {
panic("unsupported method")
}

func (c *multiQueryNoDelayRecordSet) Close() error {
return nil
}

func (c *multiQueryNoDelayRecordSet) AffectedRows() uint64 {
return c.affectedRows
}

func (c *multiQueryNoDelayRecordSet) LastMessage() string {
return c.lastMessage
}

func (c *multiQueryNoDelayRecordSet) WarnCount() uint16 {
return c.warnCount
}

func (c *multiQueryNoDelayRecordSet) Status() uint16 {
return c.status
}

func (c *multiQueryNoDelayRecordSet) LastInsertID() uint64 {
return c.lastInsertID
}
14 changes: 14 additions & 0 deletions util/sqlexec/restricted_sql_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,17 @@ type RecordSet interface {
// restart the iteration.
Close() error
}

// MultiQueryNoDelayResult is an interface for one no-delay result for one statement in multi-queries.
type MultiQueryNoDelayResult interface {
// AffectedRows return affected row for one statement in multi-queries.
AffectedRows() uint64
// LastMessage return last message for one statement in multi-queries.
LastMessage() string
// WarnCount return warn count for one statement in multi-queries.
WarnCount() uint16
// Status return status when executing one statement in multi-queries.
Status() uint16
// LastInsertID return last insert id for one statement in multi-queries.
LastInsertID() uint64
}

0 comments on commit 77b6858

Please sign in to comment.