Skip to content

Commit

Permalink
Implement HostProvider and connection options
Browse files Browse the repository at this point in the history
- Added variadic options to zk.Connect
- Initial implementation of HostProvider, and DNSHostProvider, the
  default HostProvider.
  • Loading branch information
zellyn committed Oct 28, 2015
1 parent 913027e commit 57af1c8
Show file tree
Hide file tree
Showing 3 changed files with 270 additions and 47 deletions.
127 changes: 80 additions & 47 deletions zk/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,15 @@ type Conn struct {
sessionTimeoutMs int32 // session timeout in milliseconds
passwd []byte

dialer Dialer
servers []string
serverIndex int // remember last server that was tried during connect to round-robin attempts to servers
lastServerIndex int // index of the last server that was successfully connected to and authenticated with
conn net.Conn
eventChan chan Event
shouldQuit chan struct{}
pingInterval time.Duration
recvTimeout time.Duration
connectTimeout time.Duration
dialer Dialer
hostProvider HostProvider
server string // remember the address/port of the current server
conn net.Conn
eventChan chan Event
shouldQuit chan struct{}
pingInterval time.Duration
recvTimeout time.Duration
connectTimeout time.Duration

sendChan chan *request
requests map[int32]*request // Xid -> pending request
Expand All @@ -92,6 +91,9 @@ type Conn struct {
logger Logger
}

// connOption represents a connection option. Connect applies each option to
// the freshly built Conn, in the order given, before the host provider is
// initialized, so an option may replace any default set by Connect.
type connOption func(c *Conn)

type request struct {
xid int32
opcode int32
Expand Down Expand Up @@ -122,20 +124,35 @@ type Event struct {
Server string // For connection events
}

// Connect establishes a new connection to a pool of zookeeper servers
// using the default net.Dialer. See ConnectWithDialer for further
// information about session timeout.
func Connect(servers []string, sessionTimeout time.Duration) (*Conn, <-chan Event, error) {
return ConnectWithDialer(servers, sessionTimeout, nil)
// HostProvider is used to represent a set of hosts a ZooKeeper client should connect to.
// It is an analog of the Java equivalent:
// http://svn.apache.org/viewvc/zookeeper/trunk/src/java/main/org/apache/zookeeper/client/HostProvider.java?view=markup
//
// Connect calls Init with the server list; the connection loop then calls Next
// before every dial attempt and Connected once a session has been authenticated.
type HostProvider interface {
// Init is called first, with the servers specified in the connection string.
// A non-nil error aborts Connect and is returned to its caller.
Init(servers []string) error
// Len returns the number of servers.
Len() int
// Next returns the next server to connect to. retryStart will be true if we've looped through
// all known servers without Connected() being called.
Next() (server string, retryStart bool)
// Connected notifies the HostProvider of a successful connection.
Connected()
}

// ConnectWithDialer establishes a new connection to a pool of zookeeper servers
// using a custom Dialer. See Connect for further information about session timeout.
//
// A nil dialer selects the default dialer, exactly as it did before connection
// options were introduced.
//
// Deprecated: use Connect with the WithDialer option instead.
func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Dialer) (*Conn, <-chan Event, error) {
	if dialer == nil {
		// Forwarding WithDialer(nil) would overwrite the default dialer with
		// nil and make the first connection attempt call a nil function.
		return Connect(servers, sessionTimeout)
	}
	return Connect(servers, sessionTimeout, WithDialer(dialer))
}

// ConnectWithDialer establishes a new connection to a pool of zookeeper
// Connect establishes a new connection to a pool of zookeeper
// servers. The provided session timeout sets the amount of time for which
// a session is considered valid after losing connection to a server. Within
// the session timeout it's possible to reestablish a connection to a different
// server and keep the same session. This means any ephemeral nodes and
// watches are maintained.
func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Dialer) (*Conn, <-chan Event, error) {
func Connect(servers []string, sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) {
if len(servers) == 0 {
return nil, nil, errors.New("zk: server list must not be empty")
}
Expand All @@ -154,28 +171,33 @@ func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Di
stringShuffle(srvs)

ec := make(chan Event, eventChanSize)
if dialer == nil {
dialer = net.DialTimeout
}
conn := Conn{
dialer: dialer,
servers: srvs,
serverIndex: 0,
lastServerIndex: -1,
conn: nil,
state: StateDisconnected,
eventChan: ec,
shouldQuit: make(chan struct{}),
connectTimeout: 1 * time.Second,
sendChan: make(chan *request, sendChanSize),
requests: make(map[int32]*request),
watchers: make(map[watchPathType][]chan Event),
passwd: emptyPassword,
logger: DefaultLogger,
conn := &Conn{
dialer: net.DialTimeout,
hostProvider: &DNSHostProvider{},
conn: nil,
state: StateDisconnected,
eventChan: ec,
shouldQuit: make(chan struct{}),
connectTimeout: 1 * time.Second,
sendChan: make(chan *request, sendChanSize),
requests: make(map[int32]*request),
watchers: make(map[watchPathType][]chan Event),
passwd: emptyPassword,
logger: DefaultLogger,

// Debug
reconnectDelay: 0,
}

// Set provided options.
for _, option := range options {
option(conn)
}

if err := conn.hostProvider.Init(srvs); err != nil {
return nil, nil, err
}

conn.setTimeouts(int32(sessionTimeout / time.Millisecond))

go func() {
Expand All @@ -184,7 +206,21 @@ func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Di
conn.invalidateWatches(ErrClosing)
close(conn.eventChan)
}()
return &conn, ec, nil
return conn, ec, nil
}

// WithDialer returns a connection option specifying a non-default Dialer.
//
// A nil dialer leaves the Conn's current dialer untouched. Storing nil would
// make the first connection attempt call a nil function, and callers of the
// older ConnectWithDialer API pass nil to ask for the default dialer.
func WithDialer(dialer Dialer) connOption {
	return func(c *Conn) {
		if dialer != nil {
			c.dialer = dialer
		}
	}
}

// WithHostProvider returns a connection option specifying a non-default HostProvider.
//
// A nil hostProvider leaves the Conn's current provider untouched; storing nil
// would make Connect fail with a nil dereference when it initializes the
// provider.
func WithHostProvider(hostProvider HostProvider) connOption {
	return func(c *Conn) {
		if hostProvider != nil {
			c.hostProvider = hostProvider
		}
	}
}

func (c *Conn) Close() {
Expand Down Expand Up @@ -217,17 +253,18 @@ func (c *Conn) setTimeouts(sessionTimeoutMs int32) {
func (c *Conn) setState(state State) {
atomic.StoreInt32((*int32)(&c.state), int32(state))
select {
case c.eventChan <- Event{Type: EventSession, State: state, Server: c.servers[c.serverIndex]}:
case c.eventChan <- Event{Type: EventSession, State: state, Server: c.server}:
default:
// panic("zk: event channel full - it must be monitored and never allowed to be full")
}
}

func (c *Conn) connect() error {
var retryStart bool
for {
c.serverIndex = (c.serverIndex + 1) % len(c.servers)
c.server, retryStart = c.hostProvider.Next()
c.setState(StateConnecting)
if c.serverIndex == c.lastServerIndex {
if retryStart {
c.flushUnsentRequests(ErrNoServer)
select {
case <-time.After(time.Second):
Expand All @@ -237,20 +274,17 @@ func (c *Conn) connect() error {
c.flushUnsentRequests(ErrClosing)
return ErrClosing
}
} else if c.lastServerIndex < 0 {
// lastServerIndex defaults to -1 to avoid a delay on the initial connect
c.lastServerIndex = 0
}

zkConn, err := c.dialer("tcp", c.servers[c.serverIndex], c.connectTimeout)
zkConn, err := c.dialer("tcp", c.server, c.connectTimeout)
if err == nil {
c.conn = zkConn
c.setState(StateConnected)
c.logger.Printf("Connected to %s", c.servers[c.serverIndex])
c.logger.Printf("Connected to %s", c.server)
return nil
}

c.logger.Printf("Failed to connect to %s: %+v", c.servers[c.serverIndex], err)
c.logger.Printf("Failed to connect to %s: %+v", c.server, err)
}
}

Expand All @@ -271,7 +305,7 @@ func (c *Conn) loop() {
c.conn.Close()
case err == nil:
c.logger.Printf("Authenticated: id=%d, timeout=%d", c.sessionID, c.sessionTimeoutMs)
c.lastServerIndex = c.serverIndex
c.hostProvider.Connected() // mark success
closeChan := make(chan struct{}) // channel to tell send loop stop
var wg sync.WaitGroup

Expand Down Expand Up @@ -810,7 +844,6 @@ func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) {
_, err := c.request(opGetAcl, &getAclRequest{Path: path}, res, nil)
return res.Acl, &res.Stat, err
}

func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) {
res := &setAclResponse{}
_, err := c.request(opSetAcl, &setAclRequest{Path: path, Acl: acl, Version: version}, res, nil)
Expand Down
87 changes: 87 additions & 0 deletions zk/dnshostprovider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package zk

import (
"fmt"
"net"
"sync"
)

// DNSHostProvider is the default HostProvider. It currently matches
// the Java StaticHostProvider, resolving hosts from DNS once during
// the call to Init. It could be easily extended to re-query DNS
// periodically or if there is trouble connecting.
//
// The zero value is not ready for use: Next indexes into the server
// list modulo its length, so Init must have succeeded first.
type DNSHostProvider struct {
mu sync.Mutex // Protects everything, so we can add asynchronous updates later.
// servers holds the resolved "address:port" entries, shuffled by Init.
servers []string
// curr is the index into servers most recently handed out by Next;
// Init resets it to -1.
curr int
// last is the index Next compares curr against to report retryStart.
// Init resets it to -1, the first call to Next moves it to 0, and
// Connected sets it to curr.
last int
lookupHost func(string) ([]string, error) // Override of net.LookupHost, for testing.
}

// Init is called first, with the servers specified in the connection
// string. Every entry must be in "host:port" form; the host part is
// resolved (through the lookupHost override when one is set, otherwise
// net.LookupHost) and each returned address is paired with the entry's
// port. The combined list is shuffled and the rotation state is reset.
//
// The first split or lookup error is returned as-is, and an error is
// also returned when no address at all was resolved. The provider's
// previous state is left untouched on any error.
func (hp *DNSHostProvider) Init(servers []string) error {
	hp.mu.Lock()
	defer hp.mu.Unlock()

	// Use the real resolver unless an override was injected.
	resolve := net.LookupHost
	if hp.lookupHost != nil {
		resolve = hp.lookupHost
	}

	var resolved []string
	for _, hostPort := range servers {
		host, port, splitErr := net.SplitHostPort(hostPort)
		if splitErr != nil {
			return splitErr
		}

		ips, lookupErr := resolve(host)
		if lookupErr != nil {
			return lookupErr
		}

		for i := range ips {
			resolved = append(resolved, net.JoinHostPort(ips[i], port))
		}
	}

	if len(resolved) == 0 {
		return fmt.Errorf("No hosts found for addresses %q", servers)
	}

	// Shuffle so that clients do not all start with the same server.
	stringShuffle(resolved)

	hp.servers, hp.curr, hp.last = resolved, -1, -1
	return nil
}

// Len returns the number of servers currently held by the provider,
// i.e. the number of addresses resolved by the last successful Init.
func (hp *DNSHostProvider) Len() int {
	hp.mu.Lock()
	count := len(hp.servers)
	hp.mu.Unlock()
	return count
}

// Next advances to the following server in the list, wrapping around at
// the end, and returns it. retryStart is true when the server just
// selected is the one recorded as last, meaning every known server has
// been offered since Connected() was last called (or since the first
// call to Next, if there has been no successful connection yet).
func (hp *DNSHostProvider) Next() (server string, retryStart bool) {
	hp.mu.Lock()
	defer hp.mu.Unlock()

	idx := (hp.curr + 1) % len(hp.servers)
	hp.curr = idx
	retryStart = idx == hp.last

	// last is -1 only until the very first call; anchoring it afterwards
	// keeps that first call from being reported as a full loop.
	if hp.last == -1 {
		hp.last = 0
	}

	server = hp.servers[idx]
	return server, retryStart
}

// Connected notifies the HostProvider of a successful connection. The
// server most recently returned by Next becomes the reference point
// against which Next reports retryStart.
func (hp *DNSHostProvider) Connected() {
	hp.mu.Lock()
	hp.last = hp.curr
	hp.mu.Unlock()
}
Loading

0 comments on commit 57af1c8

Please sign in to comment.