From 3449f0afde895e04c5e4ce2c7242b51dd9b7c780 Mon Sep 17 00:00:00 2001 From: Pawel Rozlach Date: Tue, 30 Aug 2016 16:31:49 +0200 Subject: [PATCH 1/4] Re-submit authentication data on reconnect In case when reconnection occurs, it is necessary to re-authenticate with ZK because it "forgets" the authentication data. --- zk/conn.go | 50 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/zk/conn.go b/zk/conn.go index 43d8f687..06ed5a2e 100644 --- a/zk/conn.go +++ b/zk/conn.go @@ -61,6 +61,11 @@ type Logger interface { Printf(string, ...interface{}) } +type authCreds struct { + scheme string + auth []byte +} + type Conn struct { lastZxid int64 sessionID int64 @@ -81,6 +86,9 @@ type Conn struct { recvTimeout time.Duration connectTimeout time.Duration + creds []authCreds + credsMu sync.Mutex // protects server + sendChan chan *request requests map[int32]*request // Xid -> pending request requestsLock sync.Mutex @@ -357,6 +365,20 @@ func (c *Conn) loop() { wg.Done() }() + c.credsMu.Lock() + if len(c.creds) > 0 { + c.logger.Printf("Re-submitting %d credentials after reconnect", + len(c.creds)) + for _, cred := range c.creds { + err := c.addAuthInner(cred.scheme, cred.auth) + if err != nil { + c.logger.Printf("Credential re-submit failed: %s", err) + // FIXME(prozlach): lets ignore it here for now + } + } + } + c.credsMu.Unlock() + c.sendSetWatches() wg.Wait() } @@ -722,11 +744,37 @@ func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc return r.zxid, r.err } -func (c *Conn) AddAuth(scheme string, auth []byte) error { +func (c *Conn) addAuthInner(scheme string, auth []byte) error { _, err := c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil) return err } +func (c *Conn) AddAuth(scheme string, auth []byte) error { + err := c.addAuthInner(scheme, auth) + + if err != nil { + return err + } + + // Remember authdata so that it can be re-submitted on reconnect + // + // FIXME(prozlach): For now we treat "userfoo:passbar" and "userfoo:passbar2" + // as two different entries, which will be re-submitted on reconnet. Some + // research is needed on how ZK treats these cases and + // then maybe switch to something like "map[username] = password" to allow + // only single password for given user with users being unique. + obj := authCreds{ + scheme: scheme, + auth: auth, + } + + c.credsMu.Lock() + c.creds = append(c.creds, obj) + c.credsMu.Unlock() + + return nil +} + func (c *Conn) Children(path string) ([]string, *Stat, error) { res := &getChildren2Response{} _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: false}, res, nil) From 5ccb459a6adc0b315d1fc70fd92cbc90b4d0d793 Mon Sep 17 00:00:00 2001 From: Pawel Rozlach Date: Wed, 31 Aug 2016 17:01:50 +0200 Subject: [PATCH 2/4] Add missing tests --- zk/server_help.go | 22 ++++++++++++++++++++ zk/zk_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/zk/server_help.go b/zk/server_help.go index ef252455..3663064c 100644 --- a/zk/server_help.go +++ b/zk/server_help.go @@ -192,3 +192,25 @@ func (tc *TestCluster) StopServer(server string) { } panic(fmt.Sprintf("Unknown server: %s", server)) } + +func (tc *TestCluster) StartAllServers() error { + for _, s := range tc.Servers { + if err := s.Srv.Start(); err != nil { + return fmt.Errorf( + "Failed to start server listening on port `%d` : %+v", s.Port, err) + } + } + + return nil +} + +func (tc *TestCluster) StopAllServers() error { + for _, s := range tc.Servers { + if err := s.Srv.Stop(); err != nil { + return fmt.Errorf( + "Failed to stop server listening on port `%d` : %+v", s.Port, err) + } + } + + return nil +} diff --git a/zk/zk_test.go b/zk/zk_test.go index 1577798f..781cdd41 100644 --- a/zk/zk_test.go +++ b/zk/zk_test.go @@ -125,6 +125,57 @@ func TestMulti(t *testing.T) { } } +func TestIfAuthdataSurvivesReconnect(t *testing.T) { + // This test case ensures authentication data is being resubmited after + // reconnect. + testNode := "/auth-testnode" + + ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "}) + if err != nil { + t.Fatal(err) + } + + zk, _, err := ts.ConnectAll() + if err != nil { + t.Fatalf("Connect returned error: %+v", err) + } + defer zk.Close() + + acl := DigestACL(PermAll, "userfoo", "passbar") + + _, err = zk.Create(testNode, []byte("Some very secret content"), 0, acl) + if err != nil && err != ErrNodeExists { + t.Fatalf("Failed to create test node : %+v", err) + } + + _, _, err = zk.Get(testNode) + if err == nil || err != ErrNoAuth { + var msg string + + if err == nil { + msg = "Fetching data without auth should have resulted in an error" + } else { + msg = fmt.Sprintf("Expecting ErrNoAuth, got `%+v` instead", err) + } + t.Fatalf(msg) + } + + zk.AddAuth("digest", []byte("userfoo:passbar")) + + _, _, err = zk.Get(testNode) + if err != nil { + t.Fatalf("Fetching data with auth failed: %+v", err) + } + + ts.StopAllServers() + ts.StartAllServers() + + _, _, err = zk.Get(testNode) + if err != nil { + t.Fatalf("Fetching data after reconnect failed: %+v", err) + } +} + func TestMultiFailures(t *testing.T) { // This test case ensures that we return the errors associated with each // opeThis in the event a call to Multi() fails. From 79b13fe71d61b851352a5bf79f3c92f80025d653 Mon Sep 17 00:00:00 2001 From: Pawel Rozlach Date: Wed, 31 Aug 2016 17:03:26 +0200 Subject: [PATCH 3/4] Fix race in sendloop --- zk/conn.go | 183 ++++++++++++++++++++++++++++++++++------------------- 1 file changed, 119 insertions(+), 64 deletions(-) diff --git a/zk/conn.go b/zk/conn.go index 06ed5a2e..8c533068 100644 --- a/zk/conn.go +++ b/zk/conn.go @@ -94,6 +94,7 @@ type Conn struct { requestsLock sync.Mutex watchers map[watchPathType][]chan Event watchersLock sync.Mutex + closeChan chan struct{} // channel to tell send loop stop // Debug (used by unit tests) reconnectDelay time.Duration @@ -325,6 +326,65 @@ func (c *Conn) connect() error { } } +func (c *Conn) resendZkAuth(reauthReadyChan chan struct{}) { + c.credsMu.Lock() + defer c.credsMu.Unlock() + + defer close(reauthReadyChan) + + c.logger.Printf("Re-submitting `%d` credentials after reconnect", + len(c.creds)) + + for _, cred := range c.creds { + resChan, err := c.sendRequest( + opSetAuth, + &setAuthRequest{Type: 0, + Scheme: cred.scheme, + Auth: cred.auth, + }, + &setAuthResponse{}, + nil) + + if err != nil { + c.logger.Printf("Call to sendRequest failed during credential resubmit: %s", err) + // FIXME(prozlach): lets ignore errors for now + continue + } + + res := <-resChan + if res.err != nil { + c.logger.Printf("Credential re-submit failed: %s", res.err) + // FIXME(prozlach): lets ignore errors for now + continue + } + } +} + +func (c *Conn) sendRequest( + opcode int32, + req interface{}, + res interface{}, + recvFunc func(*request, *responseHeader, error), +) ( + <-chan response, + error, +) { + rq := &request{ + xid: c.nextXid(), + opcode: opcode, + pkt: req, + recvStruct: res, + recvChan: make(chan response, 1), + recvFunc: recvFunc, + } + + if err := c.sendData(rq); err != nil { + return nil, err + } + + return rq.recvChan, nil +} + func (c *Conn) loop() { for { if err := c.connect(); err != nil { @@ -342,13 +402,15 @@ func (c *Conn) loop() { c.conn.Close() case err == nil: c.logger.Printf("Authenticated: id=%d, timeout=%d", c.SessionID(), c.sessionTimeoutMs) - c.hostProvider.Connected() // mark success - closeChan := make(chan struct{}) // channel to tell send loop stop - var wg sync.WaitGroup + c.hostProvider.Connected() // mark success + c.closeChan = make(chan struct{}) // channel to tell send loop stop + reauthChan := make(chan struct{}) // channel to tell send loop that authdata has been resubmitted + var wg sync.WaitGroup wg.Add(1) go func() { - err := c.sendLoop(c.conn, closeChan) + <-reauthChan + err := c.sendLoop() c.logger.Printf("Send loop terminated: err=%v", err) c.conn.Close() // causes recv loop to EOF/exit wg.Done() @@ -361,23 +423,11 @@ func (c *Conn) loop() { if err == nil { panic("zk: recvLoop should never return nil error") } - close(closeChan) // tell send loop to exit + close(c.closeChan) // tell send loop to exit wg.Done() }() - c.credsMu.Lock() - if len(c.creds) > 0 { - c.logger.Printf("Re-submitting %d credentials after reconnect", - len(c.creds)) - for _, cred := range c.creds { - err := c.addAuthInner(cred.scheme, cred.auth) - if err != nil { - c.logger.Printf("Credential re-submit failed: %s", err) - // FIXME(prozlach): lets ignore it here for now - } - } - } - c.credsMu.Unlock() + c.resendZkAuth(reauthChan) c.sendSetWatches() wg.Wait() @@ -550,7 +600,50 @@ func (c *Conn) authenticate() error { return nil } -func (c *Conn) sendLoop(conn net.Conn, closeChan <-chan struct{}) error { +func (c *Conn) sendData(req *request) error { + buf := make([]byte, bufferSize) + + header := &requestHeader{req.xid, req.opcode} + n, err := encodePacket(buf[4:], header) + if err != nil { + req.recvChan <- response{-1, err} + return nil + } + + n2, err := encodePacket(buf[4+n:], req.pkt) + if err != nil { + req.recvChan <- response{-1, err} + return nil + } + + n += n2 + + binary.BigEndian.PutUint32(buf[:4], uint32(n)) + + c.requestsLock.Lock() + select { + case <-c.closeChan: + req.recvChan <- response{-1, ErrConnectionClosed} + c.requestsLock.Unlock() + return ErrConnectionClosed + default: + } + c.requests[req.xid] = req + c.requestsLock.Unlock() + + c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)) + _, err = c.conn.Write(buf[:n+4]) + c.conn.SetWriteDeadline(time.Time{}) + if err != nil { + req.recvChan <- response{-1, err} + c.conn.Close() + return err + } + + return nil +} + +func (c *Conn) sendLoop() error { pingTicker := time.NewTicker(c.pingInterval) defer pingTicker.Stop() @@ -558,40 +651,7 @@ func (c *Conn) sendLoop(conn net.Conn, closeChan <-chan struct{}) error { for { select { case req := <-c.sendChan: - header := &requestHeader{req.xid, req.opcode} - n, err := encodePacket(buf[4:], header) - if err != nil { - req.recvChan <- response{-1, err} - continue - } - - n2, err := encodePacket(buf[4+n:], req.pkt) - if err != nil { - req.recvChan <- response{-1, err} - continue - } - - n += n2 - - binary.BigEndian.PutUint32(buf[:4], uint32(n)) - - c.requestsLock.Lock() - select { - case <-closeChan: - req.recvChan <- response{-1, ErrConnectionClosed} - c.requestsLock.Unlock() - return ErrConnectionClosed - default: - } - c.requests[req.xid] = req - c.requestsLock.Unlock() - - conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)) - _, err = conn.Write(buf[:n+4]) - conn.SetWriteDeadline(time.Time{}) - if err != nil { - req.recvChan <- response{-1, err} - conn.Close() + if err := c.sendData(req); err != nil { return err } case <-pingTicker.C: @@ -602,14 +662,14 @@ func (c *Conn) sendLoop(conn net.Conn, closeChan <-chan struct{}) error { binary.BigEndian.PutUint32(buf[:4], uint32(n)) - conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)) - _, err = conn.Write(buf[:n+4]) - conn.SetWriteDeadline(time.Time{}) + c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)) + _, err = c.conn.Write(buf[:n+4]) + c.conn.SetWriteDeadline(time.Time{}) if err != nil { - conn.Close() + c.conn.Close() return err } - case <-closeChan: + case <-c.closeChan: return nil } } @@ -744,13 +804,8 @@ func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc return r.zxid, r.err } -func (c *Conn) addAuthInner(scheme string, auth []byte) error { - _, err := c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil) - return err -} - func (c *Conn) AddAuth(scheme string, auth []byte) error { - err := c.addAuthInner(scheme, auth) + _, err := c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil) if err != nil { return err From cc4edb777d5cf806ad064dbe2634308de162e7f6 Mon Sep 17 00:00:00 2001 From: Pawel Rozlach Date: Fri, 2 Sep 2016 22:36:04 +0200 Subject: [PATCH 4/4] Move helper buffer into connection struct --- zk/conn.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/zk/conn.go b/zk/conn.go index 8c533068..b6b8dbc1 100644 --- a/zk/conn.go +++ b/zk/conn.go @@ -100,6 +100,8 @@ type Conn struct { reconnectDelay time.Duration logger Logger + + buf []byte } // connOption represents a connection option. @@ -195,6 +197,7 @@ func Connect(servers []string, sessionTimeout time.Duration, options ...connOpti watchers: make(map[watchPathType][]chan Event), passwd: emptyPassword, logger: DefaultLogger, + buf: make([]byte, bufferSize), // Debug reconnectDelay: 0, @@ -601,16 +604,14 @@ func (c *Conn) authenticate() error { } func (c *Conn) sendData(req *request) error { - buf := make([]byte, bufferSize) - header := &requestHeader{req.xid, req.opcode} - n, err := encodePacket(buf[4:], header) + n, err := encodePacket(c.buf[4:], header) if err != nil { req.recvChan <- response{-1, err} return nil } - n2, err := encodePacket(buf[4+n:], req.pkt) + n2, err := encodePacket(c.buf[4+n:], req.pkt) if err != nil { req.recvChan <- response{-1, err} return nil @@ -618,7 +619,7 @@ func (c *Conn) sendData(req *request) error { n += n2 - binary.BigEndian.PutUint32(buf[:4], uint32(n)) + binary.BigEndian.PutUint32(c.buf[:4], uint32(n)) c.requestsLock.Lock() select { @@ -632,7 +633,7 @@ func (c *Conn) sendData(req *request) error { c.requestsLock.Unlock() c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)) - _, err = c.conn.Write(buf[:n+4]) + _, err = c.conn.Write(c.buf[:n+4]) c.conn.SetWriteDeadline(time.Time{}) if err != nil { req.recvChan <- response{-1, err} @@ -647,7 +648,6 @@ func (c *Conn) sendLoop() error { pingTicker := time.NewTicker(c.pingInterval) defer pingTicker.Stop() - buf := make([]byte, bufferSize) for { select { case req := <-c.sendChan: @@ -655,15 +655,15 @@ func (c *Conn) sendLoop() error { return err } case <-pingTicker.C: - n, err := encodePacket(buf[4:], &requestHeader{Xid: -2, Opcode: opPing}) + n, err := encodePacket(c.buf[4:], &requestHeader{Xid: -2, Opcode: opPing}) if err != nil { panic("zk: opPing should never fail to serialize") } - binary.BigEndian.PutUint32(buf[:4], uint32(n)) + binary.BigEndian.PutUint32(c.buf[:4], uint32(n)) c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)) - _, err = c.conn.Write(buf[:n+4]) + _, err = c.conn.Write(c.buf[:n+4]) c.conn.SetWriteDeadline(time.Time{}) if err != nil { c.conn.Close()