From 8e8510431d287721017e359fc3ea57c445a188b1 Mon Sep 17 00:00:00 2001 From: monkey92t Date: Wed, 26 May 2021 11:25:18 +0800 Subject: [PATCH] Improve pubsub (#1764) * Improve pubsub Signed-off-by: monkey92t * Extract code to channel struct and tweak API * Move chanSendTimeout to channel * Cleanup health check * Add WithChannelSendTimeout and tweak comments * clear notes Signed-off-by: monkey92t Co-authored-by: Vladimir Mihailenco --- pubsub.go | 211 ++++++++++++++++++++++++++++--------------------- pubsub_test.go | 19 +++++ 2 files changed, 141 insertions(+), 89 deletions(-) diff --git a/pubsub.go b/pubsub.go index c56270b44..68cea7101 100644 --- a/pubsub.go +++ b/pubsub.go @@ -2,7 +2,6 @@ package redis import ( "context" - "errors" "fmt" "strings" "sync" @@ -13,13 +12,6 @@ import ( "github.com/go-redis/redis/v8/internal/proto" ) -const ( - pingTimeout = time.Second - chanSendTimeout = time.Minute -) - -var errPingTimeout = errors.New("redis: ping timeout") - // PubSub implements Pub/Sub commands as described in // http://redis.io/topics/pubsub. Message receiving is NOT safe // for concurrent use by multiple goroutines. @@ -43,9 +35,12 @@ type PubSub struct { cmd *Cmd chOnce sync.Once - msgCh chan *Message - allCh chan interface{} - ping chan struct{} + msgCh *channel + allCh *channel +} + +func (c *PubSub) init() { + c.exit = make(chan struct{}) } func (c *PubSub) String() string { @@ -54,10 +49,6 @@ func (c *PubSub) String() string { return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", ")) } -func (c *PubSub) init() { - c.exit = make(chan struct{}) -} - func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) { c.mu.Lock() cn, err := c.conn(ctx, nil) @@ -418,6 +409,15 @@ func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) { } } +func (c *PubSub) getContext() context.Context { + if c.cmd != nil { + return c.cmd.ctx + } + return context.Background() +} + +//------------------------------------------------------------------------------ + // Channel returns a Go channel for concurrently receiving messages. // The channel is closed together with the PubSub. If the Go channel // is blocked full for 30 seconds the message is dropped. @@ -425,26 +425,24 @@ func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) { // // go-redis periodically sends ping messages to test connection health // and re-subscribes if ping can not not received for 30 seconds. -func (c *PubSub) Channel() <-chan *Message { - return c.ChannelSize(100) -} - -// ChannelSize is like Channel, but creates a Go channel -// with specified buffer size. -func (c *PubSub) ChannelSize(size int) <-chan *Message { +func (c *PubSub) Channel(opts ...ChannelOption) <-chan *Message { c.chOnce.Do(func() { - c.initPing() - c.initMsgChan(size) + c.msgCh = newChannel(c, opts...) + c.msgCh.initMsgChan() }) if c.msgCh == nil { err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions") panic(err) } - if cap(c.msgCh) != size { - err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created") - panic(err) - } - return c.msgCh + return c.msgCh.msgCh +} + +// ChannelSize is like Channel, but creates a Go channel +// with specified buffer size. +// +// Deprecated: use Channel(WithChannelSize(size)), remove in v9. +func (c *PubSub) ChannelSize(size int) <-chan *Message { + return c.Channel(WithChannelSize(size)) } // ChannelWithSubscriptions is like Channel, but message type can be either @@ -452,59 +450,101 @@ func (c *PubSub) ChannelSize(size int) <-chan *Message { // reconnections. // // ChannelWithSubscriptions can not be used together with Channel or ChannelSize. -func (c *PubSub) ChannelWithSubscriptions(ctx context.Context, size int) <-chan interface{} { +func (c *PubSub) ChannelWithSubscriptions(_ context.Context, size int) <-chan interface{} { c.chOnce.Do(func() { - c.initPing() - c.initAllChan(size) + c.allCh = newChannel(c, WithChannelSize(size)) + c.allCh.initAllChan() }) if c.allCh == nil { err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel") panic(err) } - if cap(c.allCh) != size { - err := fmt.Errorf("redis: PubSub.Channel size can not be changed once created") - panic(err) + return c.allCh.allCh +} + +type ChannelOption func(c *channel) + +// WithChannelSize specifies the Go chan size that is used to buffer incoming messages. +// +// The default is 100 messages. +func WithChannelSize(size int) ChannelOption { + return func(c *channel) { + c.chanSize = size } - return c.allCh } -func (c *PubSub) getContext() context.Context { - if c.cmd != nil { - return c.cmd.ctx +// WithChannelHealthCheckInterval specifies the health check interval. +// PubSub will ping Redis Server if it does not receive any messages within the interval. +// To disable health check, use zero interval. +// +// The default is 3 seconds. +func WithChannelHealthCheckInterval(d time.Duration) ChannelOption { + return func(c *channel) { + c.checkInterval = d } - return context.Background() } -func (c *PubSub) initPing() { +// WithChannelSendTimeout specifies that channel send timeout after which +// the message is dropped. +// +// The default is 60 seconds. +func WithChannelSendTimeout(d time.Duration) ChannelOption { + return func(c *channel) { + c.chanSendTimeout = d + } +} + +type channel struct { + pubSub *PubSub + + msgCh chan *Message + allCh chan interface{} + ping chan struct{} + + chanSize int + chanSendTimeout time.Duration + checkInterval time.Duration +} + +func newChannel(pubSub *PubSub, opts ...ChannelOption) *channel { + c := &channel{ + pubSub: pubSub, + + chanSize: 100, + chanSendTimeout: time.Minute, + checkInterval: 3 * time.Second, + } + for _, opt := range opts { + opt(c) + } + if c.checkInterval > 0 { + c.initHealthCheck() + } + return c +} + +func (c *channel) initHealthCheck() { ctx := context.TODO() c.ping = make(chan struct{}, 1) + go func() { timer := time.NewTimer(time.Minute) timer.Stop() - healthy := true for { - timer.Reset(pingTimeout) + timer.Reset(c.checkInterval) select { case <-c.ping: - healthy = true if !timer.Stop() { <-timer.C } case <-timer.C: - pingErr := c.Ping(ctx) - if healthy { - healthy = false - } else { - if pingErr == nil { - pingErr = errPingTimeout - } - c.mu.Lock() - c.reconnect(ctx, pingErr) - healthy = true - c.mu.Unlock() + if pingErr := c.pubSub.Ping(ctx); pingErr != nil { + c.pubSub.mu.Lock() + c.pubSub.reconnect(ctx, pingErr) + c.pubSub.mu.Unlock() } - case <-c.exit: + case <-c.pubSub.exit: return } } @@ -512,16 +552,17 @@ func (c *PubSub) initPing() { } // initMsgChan must be in sync with initAllChan. -func (c *PubSub) initMsgChan(size int) { +func (c *channel) initMsgChan() { ctx := context.TODO() - c.msgCh = make(chan *Message, size) + c.msgCh = make(chan *Message, c.chanSize) + go func() { timer := time.NewTimer(time.Minute) timer.Stop() var errCount int for { - msg, err := c.Receive(ctx) + msg, err := c.pubSub.Receive(ctx) if err != nil { if err == pool.ErrClosed { close(c.msgCh) @@ -548,7 +589,7 @@ func (c *PubSub) initMsgChan(size int) { case *Pong: // Ignore. case *Message: - timer.Reset(chanSendTimeout) + timer.Reset(c.chanSendTimeout) select { case c.msgCh <- msg: if !timer.Stop() { @@ -556,30 +597,28 @@ func (c *PubSub) initMsgChan(size int) { } case <-timer.C: internal.Logger.Printf( - c.getContext(), - "redis: %s channel is full for %s (message is dropped)", - c, - chanSendTimeout, - ) + ctx, "redis: %s channel is full for %s (message is dropped)", + c, c.chanSendTimeout) } default: - internal.Logger.Printf(c.getContext(), "redis: unknown message type: %T", msg) + internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg) } } }() } // initAllChan must be in sync with initMsgChan. -func (c *PubSub) initAllChan(size int) { +func (c *channel) initAllChan() { ctx := context.TODO() - c.allCh = make(chan interface{}, size) + c.allCh = make(chan interface{}, c.chanSize) + go func() { - timer := time.NewTimer(pingTimeout) + timer := time.NewTimer(time.Minute) timer.Stop() var errCount int for { - msg, err := c.Receive(ctx) + msg, err := c.pubSub.Receive(ctx) if err != nil { if err == pool.ErrClosed { close(c.allCh) @@ -601,29 +640,23 @@ func (c *PubSub) initAllChan(size int) { } switch msg := msg.(type) { - case *Subscription: - c.sendMessage(msg, timer) case *Pong: // Ignore. - case *Message: - c.sendMessage(msg, timer) + case *Subscription, *Message: + timer.Reset(c.chanSendTimeout) + select { + case c.allCh <- msg: + if !timer.Stop() { + <-timer.C + } + case <-timer.C: + internal.Logger.Printf( + ctx, "redis: %s channel is full for %s (message is dropped)", + c, c.chanSendTimeout) + } default: - internal.Logger.Printf(c.getContext(), "redis: unknown message type: %T", msg) + internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg) } } }() } - -func (c *PubSub) sendMessage(msg interface{}, timer *time.Timer) { - timer.Reset(pingTimeout) - select { - case c.allCh <- msg: - if !timer.Stop() { - <-timer.C - } - case <-timer.C: - internal.Logger.Printf( - c.getContext(), - "redis: %s channel is full for %s (message is dropped)", c, pingTimeout) - } -} diff --git a/pubsub_test.go b/pubsub_test.go index d32d5e0b1..3e0bf9b0a 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -473,4 +473,23 @@ var _ = Describe("PubSub", func() { Fail("timeout") } }) + + It("should ChannelMessage", func() { + pubsub := client.Subscribe(ctx, "mychannel") + defer pubsub.Close() + + ch := pubsub.Channel( + redis.WithChannelSize(10), + redis.WithChannelHealthCheckInterval(time.Second), + ) + + text := "test channel message" + err := client.Publish(ctx, "mychannel", text).Err() + Expect(err).NotTo(HaveOccurred()) + + var msg *redis.Message + Eventually(ch).Should(Receive(&msg)) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal(text)) + }) })