Skip to content

Commit

Permalink
Improve pubsub example
Browse files Browse the repository at this point in the history
  • Loading branch information
garyburd committed Nov 27, 2017
1 parent 4a7d9db commit 1e086fa
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 87 deletions.
164 changes: 164 additions & 0 deletions redis/pubsub_example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// Copyright 2012 Gary Burd
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

package redis_test

import (
"context"
"fmt"
"time"

"github.com/garyburd/redigo/redis"
)

// listenPubSubChannels listens for messages on Redis pubsub channels. The
// onStart function is called after the channels are subscribed. The onMessage
// function is called for each message.
func listenPubSubChannels(ctx context.Context, redisServerAddr string,
onStart func() error,
onMessage func(channel string, data []byte) error,
channels ...string) error {

// A ping is set to the server with this period to test for the health of
// the connection and server.
const healthCheckPeriod = time.Minute

c, err := redis.Dial("tcp", redisServerAddr,
// Read timeout on server should be greater than ping period.
redis.DialReadTimeout(healthCheckPeriod+10*time.Second),
redis.DialWriteTimeout(10*time.Second))
if err != nil {
return err
}
defer c.Close()

psc := redis.PubSubConn{Conn: c}

if err := psc.Subscribe(redis.Args{}.AddFlat(channels)...); err != nil {
return err
}

done := make(chan error, 1)

// Start a goroutine to receive notifications from the server.
go func() {
for {
switch n := psc.Receive().(type) {
case error:
done <- n
return
case redis.Message:
if err := onMessage(n.Channel, n.Data); err != nil {
done <- err
return
}
case redis.Subscription:
switch n.Count {
case len(channels):
// Notify application when all channels are subscribed.
if err := onStart(); err != nil {
done <- err
return
}
case 0:
// Return from the goroutine when all channels are unsubscribed.
done <- nil
return
}
}
}
}()

ticker := time.NewTicker(healthCheckPeriod)
defer ticker.Stop()
loop:
for err == nil {
select {
case <-ticker.C:
// Send ping to test health of connection and server. If
// corresponding pong is not received, then receive on the
// connection will timeout and the receive goroutine will exit.
if err = psc.Ping(""); err != nil {
break loop
}
case <-ctx.Done():
break loop
case err := <-done:
// Return error from the receive goroutine.
return err
}
}

// Signal the receiving goroutine to exit by unsubscribing from all channels.
psc.Unsubscribe()

// Wait for goroutine to complete.
return <-done
}

func publish() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()

c.Do("PUBLISH", "c1", "hello")
c.Do("PUBLISH", "c2", "world")
c.Do("PUBLISH", "c1", "goodbye")
}

// This example shows how receive pubsub notifications with cancelation and
// health checks.
func ExamplePubSubConn() {
redisServerAddr, err := serverAddr()
if err != nil {
fmt.Println(err)
return
}

ctx, cancel := context.WithCancel(context.Background())

err = listenPubSubChannels(ctx,
redisServerAddr,
func() error {
// The start callback is a good place to backfill missed
// notifications. For the purpose of this example, a goroutine is
// started to send notifications.
go publish()
return nil
},
func(channel string, message []byte) error {
fmt.Printf("channel: %s, message: %s\n", channel, message)

// For the purpose of this example, cancel the listener's context
// after receiving last message sent by publish().
if string(message) == "goodbye" {
cancel()
}
return nil
},
"c1", "c2")

if err != nil {
fmt.Println(err)
return
}

// Output:
// channel: c1, message: hello
// channel: c2, message: world
// channel: c1, message: goodbye
}
81 changes: 0 additions & 81 deletions redis/pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,93 +15,12 @@
package redis_test

import (
"fmt"
"reflect"
"sync"
"testing"

"github.com/garyburd/redigo/redis"
)

func publish(channel, value interface{}) {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
c.Do("PUBLISH", channel, value)
}

// Applications can receive pushed messages from one goroutine and manage subscriptions from another goroutine.
func ExamplePubSubConn() {
c, err := dial()
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
var wg sync.WaitGroup
wg.Add(2)

psc := redis.PubSubConn{Conn: c}

// This goroutine receives and prints pushed notifications from the server.
// The goroutine exits when the connection is unsubscribed from all
// channels or there is an error.
go func() {
defer wg.Done()
for {
switch n := psc.Receive().(type) {
case redis.Message:
fmt.Printf("Message: %s %s\n", n.Channel, n.Data)
case redis.PMessage:
fmt.Printf("PMessage: %s %s %s\n", n.Pattern, n.Channel, n.Data)
case redis.Subscription:
fmt.Printf("Subscription: %s %s %d\n", n.Kind, n.Channel, n.Count)
if n.Count == 0 {
return
}
case error:
fmt.Printf("error: %v\n", n)
return
}
}
}()

// This goroutine manages subscriptions for the connection.
go func() {
defer wg.Done()

psc.Subscribe("example")
psc.PSubscribe("p*")

// The following function calls publish a message using another
// connection to the Redis server.
publish("example", "hello")
publish("example", "world")
publish("pexample", "foo")
publish("pexample", "bar")

// Unsubscribe from all connections. This will cause the receiving
// goroutine to exit.
psc.Unsubscribe()
psc.PUnsubscribe()
}()

wg.Wait()

// Output:
// Subscription: subscribe example 1
// Subscription: psubscribe p* 2
// Message: example hello
// Message: example world
// PMessage: p* pexample foo
// PMessage: p* pexample bar
// Subscription: unsubscribe example 1
// Subscription: punsubscribe p* 0
}

func expectPushed(t *testing.T, c redis.PubSubConn, message string, expected interface{}) {
actual := c.Receive()
if !reflect.DeepEqual(actual, expected) {
Expand Down
5 changes: 5 additions & 0 deletions redis/reply_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ func dial() (redis.Conn, error) {
return redis.DialDefaultServer()
}

// serverAddr wraps DefaultServerAddr() with a more suitable function name for examples.
func serverAddr() (string, error) {
return redis.DefaultServerAddr()
}

func ExampleBool() {
c, err := dial()
if err != nil {
Expand Down
15 changes: 9 additions & 6 deletions redis/test_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,29 +127,32 @@ func stopDefaultServer() {
}
}

// startDefaultServer starts the default server if not already running.
func startDefaultServer() error {
// DefaultServerAddr starts the test server if not already started and returns
// the address of that server.
func DefaultServerAddr() (string, error) {
defaultServerMu.Lock()
defer defaultServerMu.Unlock()
addr := fmt.Sprintf("%v:%d", *serverAddress, *serverBasePort)
if defaultServer != nil || defaultServerErr != nil {
return defaultServerErr
return addr, defaultServerErr
}
defaultServer, defaultServerErr = NewServer(
"default",
"--port", strconv.Itoa(*serverBasePort),
"--bind", *serverAddress,
"--save", "",
"--appendonly", "no")
return defaultServerErr
return addr, defaultServerErr
}

// DialDefaultServer starts the test server if not already started and dials a
// connection to the server.
func DialDefaultServer() (Conn, error) {
if err := startDefaultServer(); err != nil {
addr, err := DefaultServerAddr()
if err != nil {
return nil, err
}
c, err := Dial("tcp", fmt.Sprintf("%v:%d", *serverAddress, *serverBasePort), DialReadTimeout(1*time.Second), DialWriteTimeout(1*time.Second))
c, err := Dial("tcp", addr, DialReadTimeout(1*time.Second), DialWriteTimeout(1*time.Second))
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 1e086fa

Please sign in to comment.