From f2d0c37535c6f7986b48f4a9d4ae011ac855efdc Mon Sep 17 00:00:00 2001 From: Samuel Stauffer Date: Wed, 12 Dec 2012 17:49:29 -0800 Subject: [PATCH] Implement Close --- conn.go | 55 ++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 13 deletions(-) diff --git a/conn.go b/conn.go index 06ab4065..2b7325a7 100644 --- a/conn.go +++ b/conn.go @@ -35,6 +35,7 @@ type Conn struct { conn net.Conn state State eventChan chan Event + shouldQuit chan bool pingInterval time.Duration recvTimeout time.Duration connectTimeout time.Duration @@ -78,6 +79,7 @@ func Connect(servers []string, recvTimeout time.Duration) (*Conn, <-chan Event, conn: nil, state: StateDisconnected, eventChan: ec, + shouldQuit: make(chan bool), recvTimeout: recvTimeout, pingInterval: 10 * time.Second, connectTimeout: 1 * time.Second, @@ -92,11 +94,8 @@ func Connect(servers []string, recvTimeout time.Duration) (*Conn, <-chan Event, } func (c *Conn) Close() { - // TODO - - // if c.conn != nil { - // c.conn.Close() - // } + close(c.shouldQuit) + c.disconnect() } func (c *Conn) connect() { @@ -120,25 +119,43 @@ func (c *Conn) connect() { } } +func (c *Conn) disconnect() { + c.request(-1, nil, nil) +} + func (c *Conn) loop() { for { c.connect() err := c.authenticate() if err == nil { closeChan := make(chan bool) - go c.sendLoop(c.conn, closeChan) - err = c.recvLoop(c.conn) - if err == nil { - panic("zk: recvLoop should never return nil error") - } - close(closeChan) + sendDone := make(chan bool, 1) + go func() { + c.sendLoop(c.conn, closeChan) + c.conn.Close() + close(sendDone) + }() + + recvDone := make(chan bool, 1) + go func() { + err = c.recvLoop(c.conn) + if err == nil { + panic("zk: recvLoop should never return nil error") + } + close(closeChan) + <-sendDone // wait for send loop to exit + close(recvDone) + }() + + <-recvDone } - c.conn.Close() c.state = StateDisconnected c.eventChan <- Event{EventSession, c.state, ""} - log.Println(err) + if !strings.Contains(err.Error(), "use of closed network connection") { + log.Println(err) + } c.requestsLock.Lock() // Error out any pending requests @@ -147,6 +164,12 @@ func (c *Conn) loop() { } c.requests = make(map[int32]*request) c.requestsLock.Unlock() + + select { + case <-c.shouldQuit: + return + default: + } } } @@ -223,6 +246,12 @@ func (c *Conn) sendLoop(conn net.Conn, closeChan <-chan bool) error { for { select { case req := <-c.sendChan: + if req.opcode < 0 { + // Asked to quit + req.recvChan <- nil + return nil + } + header := &requestHeader{req.xid, req.opcode} n, err := encodePacket(buf[4:], header) if err != nil {