Skip to content

Commit

Permalink
Merge pull request samuel#139 from vespian/vespian/reauth-on-reconnect
Browse files Browse the repository at this point in the history
Re-submit authentication data on reconnect
  • Loading branch information
samuel authored Sep 2, 2016
2 parents d8d71fb + cc4edb7 commit 87e1bca
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 49 deletions.
201 changes: 152 additions & 49 deletions zk/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ type Logger interface {
Printf(string, ...interface{})
}

type authCreds struct {
scheme string
auth []byte
}

type Conn struct {
lastZxid int64
sessionID int64
Expand All @@ -81,16 +86,22 @@ type Conn struct {
recvTimeout time.Duration
connectTimeout time.Duration

creds []authCreds
credsMu sync.Mutex // protects server

sendChan chan *request
requests map[int32]*request // Xid -> pending request
requestsLock sync.Mutex
watchers map[watchPathType][]chan Event
watchersLock sync.Mutex
closeChan chan struct{} // channel to tell send loop stop

// Debug (used by unit tests)
reconnectDelay time.Duration

logger Logger

buf []byte
}

// connOption represents a connection option.
Expand Down Expand Up @@ -186,6 +197,7 @@ func Connect(servers []string, sessionTimeout time.Duration, options ...connOpti
watchers: make(map[watchPathType][]chan Event),
passwd: emptyPassword,
logger: DefaultLogger,
buf: make([]byte, bufferSize),

// Debug
reconnectDelay: 0,
Expand Down Expand Up @@ -317,6 +329,65 @@ func (c *Conn) connect() error {
}
}

func (c *Conn) resendZkAuth(reauthReadyChan chan struct{}) {
c.credsMu.Lock()
defer c.credsMu.Unlock()

defer close(reauthReadyChan)

c.logger.Printf("Re-submitting `%d` credentials after reconnect",
len(c.creds))

for _, cred := range c.creds {
resChan, err := c.sendRequest(
opSetAuth,
&setAuthRequest{Type: 0,
Scheme: cred.scheme,
Auth: cred.auth,
},
&setAuthResponse{},
nil)

if err != nil {
c.logger.Printf("Call to sendRequest failed during credential resubmit: %s", err)
// FIXME(prozlach): lets ignore errors for now
continue
}

res := <-resChan
if res.err != nil {
c.logger.Printf("Credential re-submit failed: %s", res.err)
// FIXME(prozlach): lets ignore errors for now
continue
}
}
}

func (c *Conn) sendRequest(
opcode int32,
req interface{},
res interface{},
recvFunc func(*request, *responseHeader, error),
) (
<-chan response,
error,
) {
rq := &request{
xid: c.nextXid(),
opcode: opcode,
pkt: req,
recvStruct: res,
recvChan: make(chan response, 1),
recvFunc: recvFunc,
}

if err := c.sendData(rq); err != nil {
return nil, err
}

return rq.recvChan, nil
}

func (c *Conn) loop() {
for {
if err := c.connect(); err != nil {
Expand All @@ -334,13 +405,15 @@ func (c *Conn) loop() {
c.conn.Close()
case err == nil:
c.logger.Printf("Authenticated: id=%d, timeout=%d", c.SessionID(), c.sessionTimeoutMs)
c.hostProvider.Connected() // mark success
closeChan := make(chan struct{}) // channel to tell send loop stop
var wg sync.WaitGroup
c.hostProvider.Connected() // mark success
c.closeChan = make(chan struct{}) // channel to tell send loop stop
reauthChan := make(chan struct{}) // channel to tell send loop that authdata has been resubmitted

var wg sync.WaitGroup
wg.Add(1)
go func() {
err := c.sendLoop(c.conn, closeChan)
<-reauthChan
err := c.sendLoop()
c.logger.Printf("Send loop terminated: err=%v", err)
c.conn.Close() // causes recv loop to EOF/exit
wg.Done()
Expand All @@ -353,10 +426,12 @@ func (c *Conn) loop() {
if err == nil {
panic("zk: recvLoop should never return nil error")
}
close(closeChan) // tell send loop to exit
close(c.closeChan) // tell send loop to exit
wg.Done()
}()

c.resendZkAuth(reauthChan)

c.sendSetWatches()
wg.Wait()
}
Expand Down Expand Up @@ -528,66 +603,73 @@ func (c *Conn) authenticate() error {
return nil
}

func (c *Conn) sendLoop(conn net.Conn, closeChan <-chan struct{}) error {
pingTicker := time.NewTicker(c.pingInterval)
defer pingTicker.Stop()
func (c *Conn) sendData(req *request) error {
header := &requestHeader{req.xid, req.opcode}
n, err := encodePacket(c.buf[4:], header)
if err != nil {
req.recvChan <- response{-1, err}
return nil
}

buf := make([]byte, bufferSize)
for {
select {
case req := <-c.sendChan:
header := &requestHeader{req.xid, req.opcode}
n, err := encodePacket(buf[4:], header)
if err != nil {
req.recvChan <- response{-1, err}
continue
}
n2, err := encodePacket(c.buf[4+n:], req.pkt)
if err != nil {
req.recvChan <- response{-1, err}
return nil
}

n2, err := encodePacket(buf[4+n:], req.pkt)
if err != nil {
req.recvChan <- response{-1, err}
continue
}
n += n2

n += n2
binary.BigEndian.PutUint32(c.buf[:4], uint32(n))

binary.BigEndian.PutUint32(buf[:4], uint32(n))
c.requestsLock.Lock()
select {
case <-c.closeChan:
req.recvChan <- response{-1, ErrConnectionClosed}
c.requestsLock.Unlock()
return ErrConnectionClosed
default:
}
c.requests[req.xid] = req
c.requestsLock.Unlock()

c.requestsLock.Lock()
select {
case <-closeChan:
req.recvChan <- response{-1, ErrConnectionClosed}
c.requestsLock.Unlock()
return ErrConnectionClosed
default:
}
c.requests[req.xid] = req
c.requestsLock.Unlock()
c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout))
_, err = c.conn.Write(c.buf[:n+4])
c.conn.SetWriteDeadline(time.Time{})
if err != nil {
req.recvChan <- response{-1, err}
c.conn.Close()
return err
}

conn.SetWriteDeadline(time.Now().Add(c.recvTimeout))
_, err = conn.Write(buf[:n+4])
conn.SetWriteDeadline(time.Time{})
if err != nil {
req.recvChan <- response{-1, err}
conn.Close()
return nil
}

func (c *Conn) sendLoop() error {
pingTicker := time.NewTicker(c.pingInterval)
defer pingTicker.Stop()

for {
select {
case req := <-c.sendChan:
if err := c.sendData(req); err != nil {
return err
}
case <-pingTicker.C:
n, err := encodePacket(buf[4:], &requestHeader{Xid: -2, Opcode: opPing})
n, err := encodePacket(c.buf[4:], &requestHeader{Xid: -2, Opcode: opPing})
if err != nil {
panic("zk: opPing should never fail to serialize")
}

binary.BigEndian.PutUint32(buf[:4], uint32(n))
binary.BigEndian.PutUint32(c.buf[:4], uint32(n))

conn.SetWriteDeadline(time.Now().Add(c.recvTimeout))
_, err = conn.Write(buf[:n+4])
conn.SetWriteDeadline(time.Time{})
c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout))
_, err = c.conn.Write(c.buf[:n+4])
c.conn.SetWriteDeadline(time.Time{})
if err != nil {
conn.Close()
c.conn.Close()
return err
}
case <-closeChan:
case <-c.closeChan:
return nil
}
}
Expand Down Expand Up @@ -724,7 +806,28 @@ func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc

func (c *Conn) AddAuth(scheme string, auth []byte) error {
_, err := c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil)
return err

if err != nil {
return err
}

// Remember authdata so that it can be re-submitted on reconnect
//
// FIXME(prozlach): For now we treat "userfoo:passbar" and "userfoo:passbar2"
// as two different entries, which will be re-submitted on reconnet. Some
// research is needed on how ZK treats these cases and
// then maybe switch to something like "map[username] = password" to allow
// only single password for given user with users being unique.
obj := authCreds{
scheme: scheme,
auth: auth,
}

c.credsMu.Lock()
c.creds = append(c.creds, obj)
c.credsMu.Unlock()

return nil
}

func (c *Conn) Children(path string) ([]string, *Stat, error) {
Expand Down
22 changes: 22 additions & 0 deletions zk/server_help.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,25 @@ func (tc *TestCluster) StopServer(server string) {
}
panic(fmt.Sprintf("Unknown server: %s", server))
}

func (tc *TestCluster) StartAllServers() error {
for _, s := range tc.Servers {
if err := s.Srv.Start(); err != nil {
return fmt.Errorf(
"Failed to start server listening on port `%d` : %+v", s.Port, err)
}
}

return nil
}

func (tc *TestCluster) StopAllServers() error {
for _, s := range tc.Servers {
if err := s.Srv.Stop(); err != nil {
return fmt.Errorf(
"Failed to stop server listening on port `%d` : %+v", s.Port, err)
}
}

return nil
}
51 changes: 51 additions & 0 deletions zk/zk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,57 @@ func TestMulti(t *testing.T) {
}
}

func TestIfAuthdataSurvivesReconnect(t *testing.T) {
// This test case ensures authentication data is being resubmited after
// reconnect.
testNode := "/auth-testnode"

ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "})
if err != nil {
t.Fatal(err)
}

zk, _, err := ts.ConnectAll()
if err != nil {
t.Fatalf("Connect returned error: %+v", err)
}
defer zk.Close()

acl := DigestACL(PermAll, "userfoo", "passbar")

_, err = zk.Create(testNode, []byte("Some very secret content"), 0, acl)
if err != nil && err != ErrNodeExists {
t.Fatalf("Failed to create test node : %+v", err)
}

_, _, err = zk.Get(testNode)
if err == nil || err != ErrNoAuth {
var msg string

if err == nil {
msg = "Fetching data without auth should have resulted in an error"
} else {
msg = fmt.Sprintf("Expecting ErrNoAuth, got `%+v` instead", err)
}
t.Fatalf(msg)
}

zk.AddAuth("digest", []byte("userfoo:passbar"))

_, _, err = zk.Get(testNode)
if err != nil {
t.Fatalf("Fetching data with auth failed: %+v", err)
}

ts.StopAllServers()
ts.StartAllServers()

_, _, err = zk.Get(testNode)
if err != nil {
t.Fatalf("Fetching data after reconnect failed: %+v", err)
}
}

func TestMultiFailures(t *testing.T) {
// This test case ensures that we return the errors associated with each
// opeThis in the event a call to Multi() fails.
Expand Down

0 comments on commit 87e1bca

Please sign in to comment.