From 865d501d07fc7461aa769741f4574351d9ecd993 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Tue, 13 Dec 2016 17:28:39 +0200 Subject: [PATCH] Add TxPipeline. --- cluster.go | 267 ++++++++++++++++++++------- cluster_test.go | 117 +++++++----- command.go | 251 +++++++------------------ internal/errors.go | 4 + internal/proto/reader.go | 12 +- internal/proto/{proto.go => scan.go} | 0 iterator.go | 1 - pipeline.go | 4 +- pipeline_test.go | 162 ++++------------ race_test.go | 31 ++++ redis.go | 194 +++++++++++++------ ring.go | 8 +- tx.go | 94 +--------- 13 files changed, 566 insertions(+), 579 deletions(-) rename internal/proto/{proto.go => scan.go} (100%) diff --git a/cluster.go b/cluster.go index f9a4fefb7..e81a7eb83 100644 --- a/cluster.go +++ b/cluster.go @@ -1,6 +1,7 @@ package redis import ( + "fmt" "math/rand" "sync" "sync/atomic" @@ -9,6 +10,7 @@ import ( "gopkg.in/redis.v5/internal" "gopkg.in/redis.v5/internal/hashtag" "gopkg.in/redis.v5/internal/pool" + "gopkg.in/redis.v5/internal/proto" ) var errClusterNoNodes = internal.RedisError("redis: cluster has no nodes") @@ -417,10 +419,6 @@ func (c *ClusterClient) Process(cmd Cmder) error { var ask bool for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { - if attempt > 0 { - cmd.reset() - } - if ask { pipe := node.Client.Pipeline() pipe.Process(NewCmd("ASKING")) @@ -655,111 +653,252 @@ func (c *ClusterClient) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { } func (c *ClusterClient) pipelineExec(cmds []Cmder) error { - var firstErr error - setFirstErr := func(err error) { - if firstErr == nil { - firstErr = err - } - } - - state := c.state() - cmdsMap := make(map[*clusterNode][]Cmder) - for _, cmd := range cmds { - _, node, err := c.cmdSlotAndNode(state, cmd) - if err != nil { - cmd.setErr(err) - setFirstErr(err) - continue - } - cmdsMap[node] = append(cmdsMap[node], cmd) + cmdsMap, err := c.mapCmdsByNode(cmds) + if err != nil { + return err } - for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { + for i := 0; i <= c.opt.MaxRedirects; i++ { failedCmds := make(map[*clusterNode][]Cmder) for node, cmds := range cmdsMap { cn, _, err := node.Client.conn() if err != nil { setCmdsErr(cmds, err) - setFirstErr(err) continue } - failedCmds, err = c.execClusterCmds(cn, cmds, failedCmds) + err = c.pipelineProcessCmds(cn, cmds, failedCmds) node.Client.putConn(cn, err, false) - if err != nil { - setFirstErr(err) - } } + if len(failedCmds) == 0 { + break + } cmdsMap = failedCmds } + var firstErr error + for _, cmd := range cmds { + if err := cmd.Err(); err != nil { + firstErr = err + break + } + } return firstErr } -func (c *ClusterClient) execClusterCmds( +func (c *ClusterClient) mapCmdsByNode(cmds []Cmder) (map[*clusterNode][]Cmder, error) { + state := c.state() + cmdsMap := make(map[*clusterNode][]Cmder) + for _, cmd := range cmds { + _, node, err := c.cmdSlotAndNode(state, cmd) + if err != nil { + return nil, err + } + cmdsMap[node] = append(cmdsMap[node], cmd) + } + return cmdsMap, nil +} + +func (c *ClusterClient) pipelineProcessCmds( cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, -) (map[*clusterNode][]Cmder, error) { +) error { cn.SetWriteTimeout(c.opt.WriteTimeout) if err := writeCmd(cn, cmds...); err != nil { setCmdsErr(cmds, err) - return failedCmds, err + return err } + // Set read timeout for all commands. + cn.SetReadTimeout(c.opt.ReadTimeout) + + return c.pipelineReadCmds(cn, cmds, failedCmds) +} + +func (c *ClusterClient) pipelineReadCmds( + cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, +) error { var firstErr error - setFirstErr := func(err error) { + for _, cmd := range cmds { + err := cmd.readReply(cn) + if err == nil { + continue + } + if firstErr == nil { firstErr = err } + + err = c.checkMovedErr(cmd, failedCmds) + if err != nil && firstErr == nil { + firstErr = err + } } + return firstErr +} - // Set read timeout for all commands. - cn.SetReadTimeout(c.opt.ReadTimeout) +func (c *ClusterClient) checkMovedErr(cmd Cmder, failedCmds map[*clusterNode][]Cmder) error { + moved, ask, addr := internal.IsMovedError(cmd.Err()) + if moved { + c.lazyReloadSlots() - for i, cmd := range cmds { - err := cmd.readReply(cn) - if err == nil { + node, err := c.nodes.Get(addr) + if err != nil { + return err + } + + failedCmds[node] = append(failedCmds[node], cmd) + } + if ask { + node, err := c.nodes.Get(addr) + if err != nil { + return err + } + + failedCmds[node] = append(failedCmds[node], NewCmd("ASKING"), cmd) + } + return nil +} + +func (c *ClusterClient) TxPipeline() *Pipeline { + pipe := Pipeline{ + exec: c.txPipelineExec, + } + pipe.cmdable.process = pipe.Process + pipe.statefulCmdable.process = pipe.Process + return &pipe +} + +func (c *ClusterClient) TxPipelined(fn func(*Pipeline) error) ([]Cmder, error) { + return c.Pipeline().pipelined(fn) +} + +func (c *ClusterClient) txPipelineExec(cmds []Cmder) error { + cmdsMap, err := c.mapCmdsBySlot(cmds) + if err != nil { + return err + } + + for slot, cmds := range cmdsMap { + node, err := c.state().slotMasterNode(slot) + if err != nil { + setCmdsErr(cmds, err) continue } - if i == 0 && internal.IsRetryableError(err) { - node, err := c.nodes.Random() - if err != nil { - setFirstErr(err) - continue + cmdsMap := map[*clusterNode][]Cmder{node: cmds} + for i := 0; i <= c.opt.MaxRedirects; i++ { + failedCmds := make(map[*clusterNode][]Cmder) + + for node, cmds := range cmdsMap { + cn, _, err := node.Client.conn() + if err != nil { + setCmdsErr(cmds, err) + continue + } + + err = c.txPipelineProcessCmds(node, cn, cmds, failedCmds) + node.Client.putConn(cn, err, false) } - cmd.reset() - failedCmds[node] = append(failedCmds[node], cmds...) + if len(failedCmds) == 0 { + break + } + cmdsMap = failedCmds + } + } + + var firstErr error + for _, cmd := range cmds { + if err := cmd.Err(); err != nil { + firstErr = err break } + } + return firstErr +} - moved, ask, addr := internal.IsMovedError(err) - if moved { - c.lazyReloadSlots() +func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) (map[int][]Cmder, error) { + state := c.state() + cmdsMap := make(map[int][]Cmder) + for _, cmd := range cmds { + slot, _, err := c.cmdSlotAndNode(state, cmd) + if err != nil { + return nil, err + } + cmdsMap[slot] = append(cmdsMap[slot], cmd) + } + return cmdsMap, nil +} - node, err := c.nodes.Get(addr) - if err != nil { - setFirstErr(err) - continue - } +func (c *ClusterClient) txPipelineProcessCmds( + node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, +) error { + cn.SetWriteTimeout(c.opt.WriteTimeout) + if err := txPipelineWriteMulti(cn, cmds); err != nil { + setCmdsErr(cmds, err) + failedCmds[node] = cmds + return err + } - cmd.reset() - failedCmds[node] = append(failedCmds[node], cmd) - } else if ask { - node, err := c.nodes.Get(addr) - if err != nil { - setFirstErr(err) - continue - } + // Set read timeout for all commands. + cn.SetReadTimeout(c.opt.ReadTimeout) - cmd.reset() - failedCmds[node] = append(failedCmds[node], NewCmd("ASKING"), cmd) - } else { - setFirstErr(err) + if err := c.txPipelineReadQueued(cn, cmds, failedCmds); err != nil { + return err + } + + _, err := pipelineReadCmds(cn, cmds) + return err +} + +func (c *ClusterClient) txPipelineReadQueued( + cn *pool.Conn, cmds []Cmder, failedCmds map[*clusterNode][]Cmder, +) error { + var firstErr error + + // Parse queued replies. + var statusCmd StatusCmd + if err := statusCmd.readReply(cn); err != nil && firstErr == nil { + firstErr = err + } + + for _, cmd := range cmds { + err := statusCmd.readReply(cn) + if err == nil { + continue + } + + cmd.setErr(err) + if firstErr == nil { + firstErr = err + } + + err = c.checkMovedErr(cmd, failedCmds) + if err != nil && firstErr == nil { + firstErr = err } } - return failedCmds, firstErr + // Parse number of replies. + line, err := cn.Rd.ReadLine() + if err != nil { + if err == Nil { + err = TxFailedErr + } + return err + } + + switch line[0] { + case proto.ErrorReply: + return proto.ParseErrorReply(line) + case proto.ArrayReply: + // ok + default: + err := fmt.Errorf("redis: expected '*', but got line %q", line) + return err + } + + return firstErr } diff --git a/cluster_test.go b/cluster_test.go index 340676801..d6707ef0f 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -373,61 +373,86 @@ var _ = Describe("ClusterClient", func() { Expect(n).To(Equal(int64(100))) }) - Describe("pipeline", func() { - It("follows redirects", func() { - slot := hashtag.Slot("A") - Expect(client.SwapSlotNodes(slot)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"})) + Describe("pipelining", func() { + var pipe *redis.Pipeline - pipe := client.Pipeline() - defer pipe.Close() + assertPipeline := func() { + It("follows redirects", func() { + slot := hashtag.Slot("A") + Expect(client.SwapSlotNodes(slot)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"})) - keys := []string{"A", "B", "C", "D", "E", "F", "G"} + keys := []string{"A", "B", "C", "D", "E", "F", "G"} - for i, key := range keys { - pipe.Set(key, key+"_value", 0) - pipe.Expire(key, time.Duration(i+1)*time.Hour) - } - cmds, err := pipe.Exec() - Expect(err).NotTo(HaveOccurred()) - Expect(cmds).To(HaveLen(14)) + for i, key := range keys { + pipe.Set(key, key+"_value", 0) + pipe.Expire(key, time.Duration(i+1)*time.Hour) + } + cmds, err := pipe.Exec() + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(14)) - for _, key := range keys { - pipe.Get(key) - pipe.TTL(key) - } - cmds, err = pipe.Exec() - Expect(err).NotTo(HaveOccurred()) - Expect(cmds).To(HaveLen(14)) - Expect(cmds[0].(*redis.StringCmd).Val()).To(Equal("A_value")) - Expect(cmds[1].(*redis.DurationCmd).Val()).To(BeNumerically("~", 1*time.Hour, time.Second)) - Expect(cmds[6].(*redis.StringCmd).Val()).To(Equal("D_value")) - Expect(cmds[7].(*redis.DurationCmd).Val()).To(BeNumerically("~", 4*time.Hour, time.Second)) - Expect(cmds[12].(*redis.StringCmd).Val()).To(Equal("G_value")) - Expect(cmds[13].(*redis.DurationCmd).Val()).To(BeNumerically("~", 7*time.Hour, time.Second)) - }) + for _, key := range keys { + pipe.Get(key) + pipe.TTL(key) + } + cmds, err = pipe.Exec() + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(14)) + Expect(cmds[0].(*redis.StringCmd).Val()).To(Equal("A_value")) + Expect(cmds[1].(*redis.DurationCmd).Val()).To(BeNumerically("~", 1*time.Hour, time.Second)) + Expect(cmds[6].(*redis.StringCmd).Val()).To(Equal("D_value")) + Expect(cmds[7].(*redis.DurationCmd).Val()).To(BeNumerically("~", 4*time.Hour, time.Second)) + Expect(cmds[12].(*redis.StringCmd).Val()).To(Equal("G_value")) + Expect(cmds[13].(*redis.DurationCmd).Val()).To(BeNumerically("~", 7*time.Hour, time.Second)) + }) - It("works with missing keys", func() { - Expect(client.Set("A", "A_value", 0).Err()).NotTo(HaveOccurred()) - Expect(client.Set("C", "C_value", 0).Err()).NotTo(HaveOccurred()) + It("works with missing keys", func() { + Expect(client.Set("A", "A_value", 0).Err()).NotTo(HaveOccurred()) + Expect(client.Set("C", "C_value", 0).Err()).NotTo(HaveOccurred()) - var a, b, c *redis.StringCmd - cmds, err := client.Pipelined(func(pipe *redis.Pipeline) error { - a = pipe.Get("A") - b = pipe.Get("B") - c = pipe.Get("C") - return nil + var a, b, c *redis.StringCmd + cmds, err := client.Pipelined(func(pipe *redis.Pipeline) error { + a = pipe.Get("A") + b = pipe.Get("B") + c = pipe.Get("C") + return nil + }) + Expect(err).To(Equal(redis.Nil)) + Expect(cmds).To(HaveLen(3)) + + Expect(a.Err()).NotTo(HaveOccurred()) + Expect(a.Val()).To(Equal("A_value")) + + Expect(b.Err()).To(Equal(redis.Nil)) + Expect(b.Val()).To(Equal("")) + + Expect(c.Err()).NotTo(HaveOccurred()) + Expect(c.Val()).To(Equal("C_value")) + }) + } + + Describe("Pipeline", func() { + BeforeEach(func() { + pipe = client.Pipeline() }) - Expect(err).To(Equal(redis.Nil)) - Expect(cmds).To(HaveLen(3)) - Expect(a.Err()).NotTo(HaveOccurred()) - Expect(a.Val()).To(Equal("A_value")) + AfterEach(func() { + Expect(pipe.Close()).NotTo(HaveOccurred()) + }) + + assertPipeline() + }) - Expect(b.Err()).To(Equal(redis.Nil)) - Expect(b.Val()).To(Equal("")) + Describe("TxPipeline", func() { + BeforeEach(func() { + pipe = client.TxPipeline() + }) + + AfterEach(func() { + Expect(pipe.Close()).NotTo(HaveOccurred()) + }) - Expect(c.Err()).NotTo(HaveOccurred()) - Expect(c.Val()).To(Equal("C_value")) + assertPipeline() }) }) @@ -624,7 +649,7 @@ var _ = Describe("ClusterClient timeout", func() { return client.ForEachNode(func(client *redis.Client) error { return client.Ping().Err() }) - }, pause).ShouldNot(HaveOccurred()) + }, 2*pause).ShouldNot(HaveOccurred()) }) testTimeout() diff --git a/command.go b/command.go index 11150414d..980981319 100644 --- a/command.go +++ b/command.go @@ -36,7 +36,6 @@ type Cmder interface { readReply(*pool.Conn) error setErr(error) - reset() readTimeout() *time.Duration @@ -50,12 +49,6 @@ func setCmdsErr(cmds []Cmder, e error) { } } -func resetCmds(cmds []Cmder) { - for _, cmd := range cmds { - cmd.reset() - } -} - func writeCmd(cn *pool.Conn, cmds ...Cmder) error { cn.Wb.Reset() for _, cmd := range cmds { @@ -167,11 +160,6 @@ func NewCmd(args ...interface{}) *Cmd { } } -func (cmd *Cmd) reset() { - cmd.val = nil - cmd.err = nil -} - func (cmd *Cmd) Val() interface{} { return cmd.val } @@ -185,16 +173,13 @@ func (cmd *Cmd) String() string { } func (cmd *Cmd) readReply(cn *pool.Conn) error { - val, err := cn.Rd.ReadReply(sliceParser) - if err != nil { - cmd.err = err + cmd.val, cmd.err = cn.Rd.ReadReply(sliceParser) + if cmd.err != nil { return cmd.err } - if b, ok := val.([]byte); ok { + if b, ok := cmd.val.([]byte); ok { // Bytes must be copied, because underlying memory is reused. cmd.val = string(b) - } else { - cmd.val = val } return nil } @@ -212,11 +197,6 @@ func NewSliceCmd(args ...interface{}) *SliceCmd { return &SliceCmd{baseCmd: cmd} } -func (cmd *SliceCmd) reset() { - cmd.val = nil - cmd.err = nil -} - func (cmd *SliceCmd) Val() []interface{} { return cmd.val } @@ -230,10 +210,10 @@ func (cmd *SliceCmd) String() string { } func (cmd *SliceCmd) readReply(cn *pool.Conn) error { - v, err := cn.Rd.ReadArrayReply(sliceParser) - if err != nil { - cmd.err = err - return err + var v interface{} + v, cmd.err = cn.Rd.ReadArrayReply(sliceParser) + if cmd.err != nil { + return cmd.err } cmd.val = v.([]interface{}) return nil @@ -252,11 +232,6 @@ func NewStatusCmd(args ...interface{}) *StatusCmd { return &StatusCmd{baseCmd: cmd} } -func (cmd *StatusCmd) reset() { - cmd.val = "" - cmd.err = nil -} - func (cmd *StatusCmd) Val() string { return cmd.val } @@ -287,11 +262,6 @@ func NewIntCmd(args ...interface{}) *IntCmd { return &IntCmd{baseCmd: cmd} } -func (cmd *IntCmd) reset() { - cmd.val = 0 - cmd.err = nil -} - func (cmd *IntCmd) Val() int64 { return cmd.val } @@ -326,11 +296,6 @@ func NewDurationCmd(precision time.Duration, args ...interface{}) *DurationCmd { } } -func (cmd *DurationCmd) reset() { - cmd.val = 0 - cmd.err = nil -} - func (cmd *DurationCmd) Val() time.Duration { return cmd.val } @@ -344,10 +309,10 @@ func (cmd *DurationCmd) String() string { } func (cmd *DurationCmd) readReply(cn *pool.Conn) error { - n, err := cn.Rd.ReadIntReply() - if err != nil { - cmd.err = err - return err + var n int64 + n, cmd.err = cn.Rd.ReadIntReply() + if cmd.err != nil { + return cmd.err } cmd.val = time.Duration(n) * cmd.precision return nil @@ -368,11 +333,6 @@ func NewTimeCmd(args ...interface{}) *TimeCmd { } } -func (cmd *TimeCmd) reset() { - cmd.val = time.Time{} - cmd.err = nil -} - func (cmd *TimeCmd) Val() time.Time { return cmd.val } @@ -386,10 +346,10 @@ func (cmd *TimeCmd) String() string { } func (cmd *TimeCmd) readReply(cn *pool.Conn) error { - v, err := cn.Rd.ReadArrayReply(timeParser) - if err != nil { - cmd.err = err - return err + var v interface{} + v, cmd.err = cn.Rd.ReadArrayReply(timeParser) + if cmd.err != nil { + return cmd.err } cmd.val = v.(time.Time) return nil @@ -408,11 +368,6 @@ func NewBoolCmd(args ...interface{}) *BoolCmd { return &BoolCmd{baseCmd: cmd} } -func (cmd *BoolCmd) reset() { - cmd.val = false - cmd.err = nil -} - func (cmd *BoolCmd) Val() bool { return cmd.val } @@ -428,27 +383,29 @@ func (cmd *BoolCmd) String() string { var ok = []byte("OK") func (cmd *BoolCmd) readReply(cn *pool.Conn) error { - v, err := cn.Rd.ReadReply(nil) + var v interface{} + v, cmd.err = cn.Rd.ReadReply(nil) // `SET key value NX` returns nil when key already exists. But // `SETNX key value` returns bool (0/1). So convert nil to bool. // TODO: is this okay? - if err == Nil { + if cmd.err == Nil { cmd.val = false + cmd.err = nil return nil } - if err != nil { - cmd.err = err - return err + if cmd.err != nil { + return cmd.err } - switch vv := v.(type) { + switch v := v.(type) { case int64: - cmd.val = vv == 1 + cmd.val = v == 1 return nil case []byte: - cmd.val = bytes.Equal(vv, ok) + cmd.val = bytes.Equal(v, ok) return nil default: - return fmt.Errorf("got %T, wanted int64 or string", v) + cmd.err = fmt.Errorf("got %T, wanted int64 or string", v) + return cmd.err } } @@ -465,11 +422,6 @@ func NewStringCmd(args ...interface{}) *StringCmd { return &StringCmd{baseCmd: cmd} } -func (cmd *StringCmd) reset() { - cmd.val = "" - cmd.err = nil -} - func (cmd *StringCmd) Val() string { return cmd.val } @@ -515,13 +467,8 @@ func (cmd *StringCmd) String() string { } func (cmd *StringCmd) readReply(cn *pool.Conn) error { - b, err := cn.Rd.ReadBytesReply() - if err != nil { - cmd.err = err - return err - } - cmd.val = string(b) - return nil + cmd.val, cmd.err = cn.Rd.ReadStringReply() + return cmd.err } //------------------------------------------------------------------------------ @@ -537,11 +484,6 @@ func NewFloatCmd(args ...interface{}) *FloatCmd { return &FloatCmd{baseCmd: cmd} } -func (cmd *FloatCmd) reset() { - cmd.val = 0 - cmd.err = nil -} - func (cmd *FloatCmd) Val() float64 { return cmd.val } @@ -572,11 +514,6 @@ func NewStringSliceCmd(args ...interface{}) *StringSliceCmd { return &StringSliceCmd{baseCmd: cmd} } -func (cmd *StringSliceCmd) reset() { - cmd.val = nil - cmd.err = nil -} - func (cmd *StringSliceCmd) Val() []string { return cmd.val } @@ -590,10 +527,10 @@ func (cmd *StringSliceCmd) String() string { } func (cmd *StringSliceCmd) readReply(cn *pool.Conn) error { - v, err := cn.Rd.ReadArrayReply(stringSliceParser) - if err != nil { - cmd.err = err - return err + var v interface{} + v, cmd.err = cn.Rd.ReadArrayReply(stringSliceParser) + if cmd.err != nil { + return cmd.err } cmd.val = v.([]string) return nil @@ -612,11 +549,6 @@ func NewBoolSliceCmd(args ...interface{}) *BoolSliceCmd { return &BoolSliceCmd{baseCmd: cmd} } -func (cmd *BoolSliceCmd) reset() { - cmd.val = nil - cmd.err = nil -} - func (cmd *BoolSliceCmd) Val() []bool { return cmd.val } @@ -630,10 +562,10 @@ func (cmd *BoolSliceCmd) String() string { } func (cmd *BoolSliceCmd) readReply(cn *pool.Conn) error { - v, err := cn.Rd.ReadArrayReply(boolSliceParser) - if err != nil { - cmd.err = err - return err + var v interface{} + v, cmd.err = cn.Rd.ReadArrayReply(boolSliceParser) + if cmd.err != nil { + return cmd.err } cmd.val = v.([]bool) return nil @@ -652,11 +584,6 @@ func NewStringStringMapCmd(args ...interface{}) *StringStringMapCmd { return &StringStringMapCmd{baseCmd: cmd} } -func (cmd *StringStringMapCmd) reset() { - cmd.val = nil - cmd.err = nil -} - func (cmd *StringStringMapCmd) Val() map[string]string { return cmd.val } @@ -670,10 +597,10 @@ func (cmd *StringStringMapCmd) String() string { } func (cmd *StringStringMapCmd) readReply(cn *pool.Conn) error { - v, err := cn.Rd.ReadArrayReply(stringStringMapParser) - if err != nil { - cmd.err = err - return err + var v interface{} + v, cmd.err = cn.Rd.ReadArrayReply(stringStringMapParser) + if cmd.err != nil { + return cmd.err } cmd.val = v.(map[string]string) return nil @@ -704,16 +631,11 @@ func (cmd *StringIntMapCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *StringIntMapCmd) reset() { - cmd.val = nil - cmd.err = nil -} - func (cmd *StringIntMapCmd) readReply(cn *pool.Conn) error { - v, err := cn.Rd.ReadArrayReply(stringIntMapParser) - if err != nil { - cmd.err = err - return err + var v interface{} + v, cmd.err = cn.Rd.ReadArrayReply(stringIntMapParser) + if cmd.err != nil { + return cmd.err } cmd.val = v.(map[string]int64) return nil @@ -732,11 +654,6 @@ func NewZSliceCmd(args ...interface{}) *ZSliceCmd { return &ZSliceCmd{baseCmd: cmd} } -func (cmd *ZSliceCmd) reset() { - cmd.val = nil - cmd.err = nil -} - func (cmd *ZSliceCmd) Val() []Z { return cmd.val } @@ -750,10 +667,10 @@ func (cmd *ZSliceCmd) String() string { } func (cmd *ZSliceCmd) readReply(cn *pool.Conn) error { - v, err := cn.Rd.ReadArrayReply(zSliceParser) - if err != nil { - cmd.err = err - return err + var v interface{} + v, cmd.err = cn.Rd.ReadArrayReply(zSliceParser) + if cmd.err != nil { + return cmd.err } cmd.val = v.([]Z) return nil @@ -775,12 +692,6 @@ func NewScanCmd(args ...interface{}) *ScanCmd { } } -func (cmd *ScanCmd) reset() { - cmd.cursor = 0 - cmd.page = nil - cmd.err = nil -} - func (cmd *ScanCmd) Val() (keys []string, cursor uint64) { return cmd.page, cmd.cursor } @@ -794,14 +705,8 @@ func (cmd *ScanCmd) String() string { } func (cmd *ScanCmd) readReply(cn *pool.Conn) error { - page, cursor, err := cn.Rd.ReadScanReply() - if err != nil { - cmd.err = err - return cmd.err - } - cmd.page = page - cmd.cursor = cursor - return nil + cmd.page, cmd.cursor, cmd.err = cn.Rd.ReadScanReply() + return cmd.err } //------------------------------------------------------------------------------ @@ -840,16 +745,11 @@ func (cmd *ClusterSlotsCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *ClusterSlotsCmd) reset() { - cmd.val = nil - cmd.err = nil -} - func (cmd *ClusterSlotsCmd) readReply(cn *pool.Conn) error { - v, err := cn.Rd.ReadArrayReply(clusterSlotsParser) - if err != nil { - cmd.err = err - return err + var v interface{} + v, cmd.err = cn.Rd.ReadArrayReply(clusterSlotsParser) + if cmd.err != nil { + return cmd.err } cmd.val = v.([]ClusterSlot) return nil @@ -913,11 +813,6 @@ func NewGeoLocationCmd(q *GeoRadiusQuery, args ...interface{}) *GeoLocationCmd { } } -func (cmd *GeoLocationCmd) reset() { - cmd.locations = nil - cmd.err = nil -} - func (cmd *GeoLocationCmd) Val() []GeoLocation { return cmd.locations } @@ -931,12 +826,12 @@ func (cmd *GeoLocationCmd) String() string { } func (cmd *GeoLocationCmd) readReply(cn *pool.Conn) error { - reply, err := cn.Rd.ReadArrayReply(newGeoLocationSliceParser(cmd.q)) - if err != nil { - cmd.err = err - return err + var v interface{} + v, cmd.err = cn.Rd.ReadArrayReply(newGeoLocationSliceParser(cmd.q)) + if cmd.err != nil { + return cmd.err } - cmd.locations = reply.([]GeoLocation) + cmd.locations = v.([]GeoLocation) return nil } @@ -969,18 +864,13 @@ func (cmd *GeoPosCmd) String() string { return cmdString(cmd, cmd.positions) } -func (cmd *GeoPosCmd) reset() { - cmd.positions = nil - cmd.err = nil -} - func (cmd *GeoPosCmd) readReply(cn *pool.Conn) error { - reply, err := cn.Rd.ReadArrayReply(geoPosSliceParser) - if err != nil { - cmd.err = err - return err + var v interface{} + v, cmd.err = cn.Rd.ReadArrayReply(geoPosSliceParser) + if cmd.err != nil { + return cmd.err } - cmd.positions = reply.([]*GeoPos) + cmd.positions = v.([]*GeoPos) return nil } @@ -1019,16 +909,11 @@ func (cmd *CommandsInfoCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *CommandsInfoCmd) reset() { - cmd.val = nil - cmd.err = nil -} - func (cmd *CommandsInfoCmd) readReply(cn *pool.Conn) error { - v, err := cn.Rd.ReadArrayReply(commandInfoSliceParser) - if err != nil { - cmd.err = err - return err + var v interface{} + v, cmd.err = cn.Rd.ReadArrayReply(commandInfoSliceParser) + if cmd.err != nil { + return cmd.err } cmd.val = v.(map[string]*CommandInfo) return nil diff --git a/internal/errors.go b/internal/errors.go index e94e123a5..67b29aec3 100644 --- a/internal/errors.go +++ b/internal/errors.go @@ -69,3 +69,7 @@ func IsMovedError(err error) (moved bool, ask bool, addr string) { func IsLoadingError(err error) bool { return strings.HasPrefix(err.Error(), "LOADING") } + +func IsExecAbortError(err error) bool { + return strings.HasPrefix(err.Error(), "EXECABORT") +} diff --git a/internal/proto/reader.go b/internal/proto/reader.go index 7e7284639..9aed3def1 100644 --- a/internal/proto/reader.go +++ b/internal/proto/reader.go @@ -70,7 +70,7 @@ func (p *Reader) ReadReply(m MultiBulkParse) (interface{}, error) { switch line[0] { case ErrorReply: - return nil, parseErrorValue(line) + return nil, ParseErrorReply(line) case StatusReply: return parseStatusValue(line) case IntReply: @@ -94,7 +94,7 @@ func (p *Reader) ReadIntReply() (int64, error) { } switch line[0] { case ErrorReply: - return 0, parseErrorValue(line) + return 0, ParseErrorReply(line) case IntReply: return parseIntValue(line) default: @@ -109,7 +109,7 @@ func (p *Reader) ReadBytesReply() ([]byte, error) { } switch line[0] { case ErrorReply: - return nil, parseErrorValue(line) + return nil, ParseErrorReply(line) case StringReply: return p.readBytesValue(line) case StatusReply: @@ -142,7 +142,7 @@ func (p *Reader) ReadArrayReply(m MultiBulkParse) (interface{}, error) { } switch line[0] { case ErrorReply: - return nil, parseErrorValue(line) + return nil, ParseErrorReply(line) case ArrayReply: n, err := parseArrayLen(line) if err != nil { @@ -161,7 +161,7 @@ func (p *Reader) ReadArrayLen() (int64, error) { } switch line[0] { case ErrorReply: - return 0, parseErrorValue(line) + return 0, ParseErrorReply(line) case ArrayReply: return parseArrayLen(line) default: @@ -272,7 +272,7 @@ func isNilReply(b []byte) bool { b[1] == '-' && b[2] == '1' } -func parseErrorValue(line []byte) error { +func ParseErrorReply(line []byte) error { return internal.RedisError(string(line[1:])) } diff --git a/internal/proto/proto.go b/internal/proto/scan.go similarity index 100% rename from internal/proto/proto.go rename to internal/proto/scan.go diff --git a/iterator.go b/iterator.go index e885985b9..4f7536081 100644 --- a/iterator.go +++ b/iterator.go @@ -58,7 +58,6 @@ func (it *ScanIterator) Next() bool { } else { it.ScanCmd._args[2] = it.ScanCmd.cursor } - it.ScanCmd.reset() it.client.process(it.ScanCmd) if it.ScanCmd.Err() != nil { return false diff --git a/pipeline.go b/pipeline.go index ef5510b2d..a0a00e209 100644 --- a/pipeline.go +++ b/pipeline.go @@ -7,6 +7,8 @@ import ( "gopkg.in/redis.v5/internal/pool" ) +type pipelineExecer func([]Cmder) error + // Pipeline implements pipelining as described in // http://redis.io/topics/pipelining. It's safe for concurrent use // by multiple goroutines. @@ -14,7 +16,7 @@ type Pipeline struct { cmdable statefulCmdable - exec func([]Cmder) error + exec pipelineExecer mu sync.Mutex cmds []Cmder diff --git a/pipeline_test.go b/pipeline_test.go index 6c6fb9625..14ba784c6 100644 --- a/pipeline_test.go +++ b/pipeline_test.go @@ -1,17 +1,15 @@ package redis_test import ( - "strconv" - "sync" - "gopkg.in/redis.v5" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -var _ = Describe("Pipeline", func() { +var _ = Describe("pipelining", func() { var client *redis.Client + var pipe *redis.Pipeline BeforeEach(func() { client = redis.NewClient(redisOptions()) @@ -22,44 +20,7 @@ var _ = Describe("Pipeline", func() { Expect(client.Close()).NotTo(HaveOccurred()) }) - It("should pipeline", func() { - set := client.Set("key2", "hello2", 0) - Expect(set.Err()).NotTo(HaveOccurred()) - Expect(set.Val()).To(Equal("OK")) - - pipeline := client.Pipeline() - set = pipeline.Set("key1", "hello1", 0) - get := pipeline.Get("key2") - incr := pipeline.Incr("key3") - getNil := pipeline.Get("key4") - - cmds, err := pipeline.Exec() - Expect(err).To(Equal(redis.Nil)) - Expect(cmds).To(HaveLen(4)) - Expect(pipeline.Close()).NotTo(HaveOccurred()) - - Expect(set.Err()).NotTo(HaveOccurred()) - Expect(set.Val()).To(Equal("OK")) - - Expect(get.Err()).NotTo(HaveOccurred()) - Expect(get.Val()).To(Equal("hello2")) - - Expect(incr.Err()).NotTo(HaveOccurred()) - Expect(incr.Val()).To(Equal(int64(1))) - - Expect(getNil.Err()).To(Equal(redis.Nil)) - Expect(getNil.Val()).To(Equal("")) - }) - - It("discards queued commands", func() { - pipeline := client.Pipeline() - pipeline.Get("key") - pipeline.Discard() - _, err := pipeline.Exec() - Expect(err).To(MatchError("redis: pipeline is empty")) - }) - - It("should support block style", func() { + It("supports block style", func() { var get *redis.StringCmd cmds, err := client.Pipelined(func(pipe *redis.Pipeline) error { get = pipe.Get("foo") @@ -72,98 +33,47 @@ var _ = Describe("Pipeline", func() { Expect(get.Val()).To(Equal("")) }) - It("should handle vals/err", func() { - pipeline := client.Pipeline() - - get := pipeline.Get("key") - Expect(get.Err()).NotTo(HaveOccurred()) - Expect(get.Val()).To(Equal("")) - Expect(pipeline.Close()).NotTo(HaveOccurred()) - }) - - It("returns an error when there are no commands", func() { - pipeline := client.Pipeline() - _, err := pipeline.Exec() - Expect(err).To(MatchError("redis: pipeline is empty")) - }) - - It("should increment correctly", func() { - const N = 20000 - key := "TestPipelineIncr" - pipeline := client.Pipeline() - for i := 0; i < N; i++ { - pipeline.Incr(key) - } - - cmds, err := pipeline.Exec() - Expect(err).NotTo(HaveOccurred()) - Expect(pipeline.Close()).NotTo(HaveOccurred()) - - Expect(len(cmds)).To(Equal(20000)) - for _, cmd := range cmds { - Expect(cmd.Err()).NotTo(HaveOccurred()) - } - - get := client.Get(key) - Expect(get.Err()).NotTo(HaveOccurred()) - Expect(get.Val()).To(Equal(strconv.Itoa(N))) - }) - - It("should PipelineEcho", func() { - const N = 1000 - - wg := &sync.WaitGroup{} - wg.Add(N) - for i := 0; i < N; i++ { - go func(i int) { - defer GinkgoRecover() - defer wg.Done() - - pipeline := client.Pipeline() + assertPipeline := func() { + It("returns an error when there are no commands", func() { + _, err := pipe.Exec() + Expect(err).To(MatchError("redis: pipeline is empty")) + }) - msg1 := "echo" + strconv.Itoa(i) - msg2 := "echo" + strconv.Itoa(i+1) + It("discards queued commands", func() { + pipe.Get("key") + pipe.Discard() + _, err := pipe.Exec() + Expect(err).To(MatchError("redis: pipeline is empty")) + }) - echo1 := pipeline.Echo(msg1) - echo2 := pipeline.Echo(msg2) + It("handles val/err", func() { + err := client.Set("key", "value", 0).Err() + Expect(err).NotTo(HaveOccurred()) - cmds, err := pipeline.Exec() - Expect(err).NotTo(HaveOccurred()) - Expect(cmds).To(HaveLen(2)) + get := pipe.Get("key") + cmds, err := pipe.Exec() + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(1)) - Expect(echo1.Err()).NotTo(HaveOccurred()) - Expect(echo1.Val()).To(Equal(msg1)) + val, err := get.Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal("value")) + }) + } - Expect(echo2.Err()).NotTo(HaveOccurred()) - Expect(echo2.Val()).To(Equal(msg2)) + Describe("Pipeline", func() { + BeforeEach(func() { + pipe = client.Pipeline() + }) - Expect(pipeline.Close()).NotTo(HaveOccurred()) - }(i) - } - wg.Wait() + assertPipeline() }) - It("should be thread-safe", func() { - const N = 1000 - - pipeline := client.Pipeline() - var wg sync.WaitGroup - wg.Add(N) - for i := 0; i < N; i++ { - go func() { - defer GinkgoRecover() - - pipeline.Ping() - wg.Done() - }() - } - wg.Wait() - - cmds, err := pipeline.Exec() - Expect(err).NotTo(HaveOccurred()) - Expect(cmds).To(HaveLen(N)) + Describe("TxPipeline", func() { + BeforeEach(func() { + pipe = client.TxPipeline() + }) - Expect(pipeline.Close()).NotTo(HaveOccurred()) + assertPipeline() }) - }) diff --git a/race_test.go b/race_test.go index 1e44bf03e..2f5e7b31d 100644 --- a/race_test.go +++ b/race_test.go @@ -245,4 +245,35 @@ var _ = Describe("races", func() { Expect(val).To(Equal(int64(C * N))) }) + It("should Pipeline", func() { + perform(C, func(id int) { + pipe := client.Pipeline() + for i := 0; i < N; i++ { + pipe.Echo(fmt.Sprint(i)) + } + + cmds, err := pipe.Exec() + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(N)) + + for i := 0; i < N; i++ { + Expect(cmds[i].(*redis.StringCmd).Val()).To(Equal(fmt.Sprint(i))) + } + }) + }) + + It("should Pipeline", func() { + pipe := client.Pipeline() + perform(N, func(id int) { + pipe.Incr("key") + }) + + cmds, err := pipe.Exec() + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(N)) + + n, err := client.Get("key").Int64() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(N))) + }) }) diff --git a/redis.go b/redis.go index 6fcd5c4a0..7bfd0abf7 100644 --- a/redis.go +++ b/redis.go @@ -7,6 +7,7 @@ import ( "gopkg.in/redis.v5/internal" "gopkg.in/redis.v5/internal/pool" + "gopkg.in/redis.v5/internal/proto" ) // Redis nil reply, .e.g. when key does not exist. @@ -96,10 +97,6 @@ func (c *baseClient) WrapProcess(fn func(oldProcess func(cmd Cmder) error) func( func (c *baseClient) defaultProcess(cmd Cmder) error { for i := 0; i <= c.opt.MaxRetries; i++ { - if i > 0 { - cmd.reset() - } - cn, _, err := c.conn() if err != nil { cmd.setErr(err) @@ -162,6 +159,129 @@ func (c *baseClient) getAddr() string { return c.opt.Addr } +type pipelineProcessor func(*pool.Conn, []Cmder) (bool, error) + +func (c *baseClient) pipelineExecer(p pipelineProcessor) pipelineExecer { + return func(cmds []Cmder) error { + var firstErr error + for i := 0; i <= c.opt.MaxRetries; i++ { + cn, _, err := c.conn() + if err != nil { + setCmdsErr(cmds, err) + return err + } + + canRetry, err := p(cn, cmds) + c.putConn(cn, err, false) + if err == nil { + return nil + } + if firstErr == nil { + firstErr = err + } + if !canRetry || !internal.IsRetryableError(err) { + break + } + } + return firstErr + } +} + +func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) { + cn.SetWriteTimeout(c.opt.WriteTimeout) + if err := writeCmd(cn, cmds...); err != nil { + setCmdsErr(cmds, err) + return true, err + } + + // Set read timeout for all commands. + cn.SetReadTimeout(c.opt.ReadTimeout) + return pipelineReadCmds(cn, cmds) +} + +func pipelineReadCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) { + for i, cmd := range cmds { + err := cmd.readReply(cn) + if err == nil { + continue + } + if i == 0 { + retry = true + } + if firstErr == nil { + firstErr = err + } + } + return false, firstErr +} + +func (c *baseClient) txPipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) { + cn.SetWriteTimeout(c.opt.WriteTimeout) + if err := txPipelineWriteMulti(cn, cmds); err != nil { + setCmdsErr(cmds, err) + return true, err + } + + // Set read timeout for all commands. + cn.SetReadTimeout(c.opt.ReadTimeout) + + if err := c.txPipelineReadQueued(cn, cmds); err != nil { + return false, err + } + + _, err := pipelineReadCmds(cn, cmds) + return false, err +} + +func txPipelineWriteMulti(cn *pool.Conn, cmds []Cmder) error { + multiExec := make([]Cmder, 0, len(cmds)+2) + multiExec = append(multiExec, NewStatusCmd("MULTI")) + multiExec = append(multiExec, cmds...) + multiExec = append(multiExec, NewSliceCmd("EXEC")) + return writeCmd(cn, multiExec...) +} + +func (c *baseClient) txPipelineReadQueued(cn *pool.Conn, cmds []Cmder) error { + var firstErr error + + // Parse queued replies. + var statusCmd StatusCmd + if err := statusCmd.readReply(cn); err != nil && firstErr == nil { + firstErr = err + } + + for _, cmd := range cmds { + err := statusCmd.readReply(cn) + if err != nil { + cmd.setErr(err) + if firstErr == nil { + firstErr = err + } + } + } + + // Parse number of replies. + line, err := cn.Rd.ReadLine() + if err != nil { + if err == Nil { + err = TxFailedErr + } + return err + } + + switch line[0] { + case proto.ErrorReply: + return proto.ParseErrorReply(line) + case proto.ArrayReply: + // ok + default: + err := fmt.Errorf("redis: expected '*', but got line %q", line) + return err + } + + return nil +} + //------------------------------------------------------------------------------ // Client is a Redis client representing a pool of zero or more @@ -202,70 +322,30 @@ func (c *Client) PoolStats() *PoolStats { } } +func (c *Client) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { + return c.Pipeline().pipelined(fn) +} + func (c *Client) Pipeline() *Pipeline { pipe := Pipeline{ - exec: c.pipelineExec, + exec: c.pipelineExecer(c.pipelineProcessCmds), } pipe.cmdable.process = pipe.Process pipe.statefulCmdable.process = pipe.Process return &pipe } -func (c *Client) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { - return c.Pipeline().pipelined(fn) -} - -func (c *Client) pipelineExec(cmds []Cmder) error { - var firstErr error - for i := 0; i <= c.opt.MaxRetries; i++ { - if i > 0 { - resetCmds(cmds) - } - - cn, _, err := c.conn() - if err != nil { - setCmdsErr(cmds, err) - return err - } - - retry, err := c.execCmds(cn, cmds) - c.putConn(cn, err, false) - if err == nil { - return nil - } - if firstErr == nil { - firstErr = err - } - if !retry { - break - } - } - return firstErr +func (c *Client) TxPipelined(fn func(*Pipeline) error) ([]Cmder, error) { + return c.TxPipeline().pipelined(fn) } -func (c *Client) execCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) { - cn.SetWriteTimeout(c.opt.WriteTimeout) - if err := writeCmd(cn, cmds...); err != nil { - setCmdsErr(cmds, err) - return true, err - } - - // Set read timeout for all commands. - cn.SetReadTimeout(c.opt.ReadTimeout) - - for i, cmd := range cmds { - err := cmd.readReply(cn) - if err == nil { - continue - } - if i == 0 && internal.IsNetworkError(err) { - return true, err - } - if firstErr == nil { - firstErr = err - } +func (c *Client) TxPipeline() *Pipeline { + pipe := Pipeline{ + exec: c.pipelineExecer(c.txPipelineProcessCmds), } - return false, firstErr + pipe.cmdable.process = pipe.Process + pipe.statefulCmdable.process = pipe.Process + return &pipe } func (c *Client) pubSub() *PubSub { diff --git a/ring.go b/ring.go index 11945b4a5..116ca1710 100644 --- a/ring.go +++ b/ring.go @@ -381,10 +381,6 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) { var failedCmdsMap map[string][]Cmder for name, cmds := range cmdsMap { - if i > 0 { - resetCmds(cmds) - } - shard, err := c.shardByName(name) if err != nil { setCmdsErr(cmds, err) @@ -403,7 +399,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) { continue } - retry, err := shard.Client.execCmds(cn, cmds) + canRetry, err := shard.Client.pipelineProcessCmds(cn, cmds) shard.Client.putConn(cn, err, false) if err == nil { continue @@ -411,7 +407,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) { if firstErr == nil { firstErr = err } - if retry { + if canRetry && internal.IsRetryableError(err) { if failedCmdsMap == nil { failedCmdsMap = make(map[string][]Cmder) } diff --git a/tx.go b/tx.go index 30beda940..af65fe979 100644 --- a/tx.go +++ b/tx.go @@ -1,11 +1,8 @@ package redis import ( - "fmt" - "gopkg.in/redis.v5/internal" "gopkg.in/redis.v5/internal/pool" - "gopkg.in/redis.v5/internal/proto" ) // Redis transaction failed. @@ -19,8 +16,6 @@ type Tx struct { cmdable statefulCmdable baseClient - - closed bool } var _ Cmdable = (*Tx)(nil) @@ -41,26 +36,20 @@ func (c *Client) Watch(fn func(*Tx) error, keys ...string) error { tx := c.newTx() if len(keys) > 0 { if err := tx.Watch(keys...).Err(); err != nil { - _ = tx.close() + _ = tx.Close() return err } } firstErr := fn(tx) - if err := tx.close(); err != nil && firstErr == nil { + if err := tx.Close(); err != nil && firstErr == nil { firstErr = err } return firstErr } // close closes the transaction, releasing any open resources. -func (c *Tx) close() error { - if c.closed { - return nil - } - c.closed = true - if err := c.Unwatch().Err(); err != nil { - internal.Logf("Unwatch failed: %s", err) - } +func (c *Tx) Close() error { + _ = c.Unwatch().Err() return c.baseClient.Close() } @@ -91,7 +80,7 @@ func (c *Tx) Unwatch(keys ...string) *StatusCmd { func (c *Tx) Pipeline() *Pipeline { pipe := Pipeline{ - exec: c.exec, + exec: c.pipelineExecer(c.txPipelineProcessCmds), } pipe.cmdable.process = pipe.Process pipe.statefulCmdable.process = pipe.Process @@ -110,76 +99,3 @@ func (c *Tx) Pipeline() *Pipeline { func (c *Tx) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) { return c.Pipeline().pipelined(fn) } - -func (c *Tx) exec(cmds []Cmder) error { - if c.closed { - return pool.ErrClosed - } - - cn, _, err := c.conn() - if err != nil { - setCmdsErr(cmds, err) - return err - } - - multiExec := make([]Cmder, 0, len(cmds)+2) - multiExec = append(multiExec, NewStatusCmd("MULTI")) - multiExec = append(multiExec, cmds...) - multiExec = append(multiExec, NewSliceCmd("EXEC")) - - err = c.execCmds(cn, multiExec) - c.putConn(cn, err, false) - return err -} - -func (c *Tx) execCmds(cn *pool.Conn, cmds []Cmder) error { - cn.SetWriteTimeout(c.opt.WriteTimeout) - err := writeCmd(cn, cmds...) - if err != nil { - setCmdsErr(cmds[1:len(cmds)-1], err) - return err - } - - // Set read timeout for all commands. - cn.SetReadTimeout(c.opt.ReadTimeout) - - // Omit last command (EXEC). - cmdsLen := len(cmds) - 1 - - // Parse queued replies. - statusCmd := cmds[0] - for i := 0; i < cmdsLen; i++ { - if err := statusCmd.readReply(cn); err != nil { - setCmdsErr(cmds[1:len(cmds)-1], err) - return err - } - } - - // Parse number of replies. - line, err := cn.Rd.ReadLine() - if err != nil { - if err == Nil { - err = TxFailedErr - } - setCmdsErr(cmds[1:len(cmds)-1], err) - return err - } - if line[0] != proto.ArrayReply { - err := fmt.Errorf("redis: expected '*', but got line %q", line) - setCmdsErr(cmds[1:len(cmds)-1], err) - return err - } - - var firstErr error - - // Parse replies. - // Loop starts from 1 to omit MULTI cmd. - for i := 1; i < cmdsLen; i++ { - cmd := cmds[i] - if err := cmd.readReply(cn); err != nil && firstErr == nil { - firstErr = err - } - } - - return firstErr -}