From 4da5b4affcdeafe73e574b687d00768ea42f0342 Mon Sep 17 00:00:00 2001 From: Samuel Stauffer Date: Wed, 29 May 2013 22:24:00 -0700 Subject: [PATCH] Add a protected create for ephemeral/sequential nodes --- conn.go | 71 ++++++++++++++++++++++++++++++++++++++++++++-------- lock.go | 62 +++++++++++++++++++++++++++------------------ lock_test.go | 6 +++-- 3 files changed, 101 insertions(+), 38 deletions(-) diff --git a/conn.go b/conn.go index 15166e8c..7f4d77a7 100644 --- a/conn.go +++ b/conn.go @@ -3,7 +3,9 @@ package zk // TODO: make sure a ping response comes back in a reasonable time import ( + "crypto/rand" "encoding/binary" + "fmt" "io" "log" "net" @@ -15,9 +17,10 @@ import ( ) const ( - bufferSize = 1536 * 1024 - eventChanSize = 5 - sendChanSize = 16 + bufferSize = 1536 * 1024 + eventChanSize = 5 + sendChanSize = 16 + protectedPrefix = "_c_" ) type watcherType int @@ -187,7 +190,11 @@ func (c *Conn) loop() { log.Println(err) } - c.flushRequests(err) + if err == ErrSessionExpired { + c.flushRequests(err) + } else { + c.flushRequests(ErrConnectionClosed) + } if c.reconnectDelay > 0 { select { @@ -377,13 +384,6 @@ func (c *Conn) sendLoop(conn net.Conn, closeChan <-chan bool) error { binary.BigEndian.PutUint32(buf[:4], uint32(n)) - _, err = conn.Write(buf[:n+4]) - if err != nil { - req.recvChan <- err - conn.Close() - return err - } - c.requestsLock.Lock() select { case <-closeChan: @@ -394,6 +394,13 @@ func (c *Conn) sendLoop(conn net.Conn, closeChan <-chan bool) error { } c.requests[req.xid] = req c.requestsLock.Unlock() + + _, err = conn.Write(buf[:n+4]) + if err != nil { + req.recvChan <- err + conn.Close() + return err + } case <-pingTicker.C: n, err := encodePacket(buf[4:], &requestHeader{Xid: -2, Opcode: opPing}) if err != nil { @@ -596,6 +603,48 @@ func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, return res.Path, err } +// Fixes a hole if the server crashes after it creates the node +func (c *Conn) CreateProtectedEphemeralSequential(path string, data []byte, acl []ACL) (string, error) { + var guid [16]byte + _, err := io.ReadFull(rand.Reader, guid[:16]) + if err != nil { + return "", err + } + guidStr := fmt.Sprintf("%x", guid) + + parts := strings.Split(path, "/") + parts[len(parts)-1] = fmt.Sprintf("%s%s-%s", protectedPrefix, guidStr, parts[len(parts)-1]) + rootPath := strings.Join(parts[:len(parts)-1], "/") + protectedPath := strings.Join(parts, "/") + + res := &createResponse{} + for i := 0; i < 3; i++ { + err = c.request(opCreate, &createRequest{protectedPath, data, acl, FlagEphemeral | FlagSequence}, res) + switch err { + case ErrSessionExpired: + // No need to search for the node since it can't exist. Just try again. + case ErrConnectionClosed: + children, _, err := c.Children(rootPath) + if err != nil { + return "", err + } + for _, p := range children { + parts := strings.Split(p, "/") + if pth := parts[len(parts)-1]; strings.HasPrefix(pth, protectedPrefix) { + if g := pth[len(protectedPrefix) : len(protectedPrefix)+32]; g == guidStr { + return rootPath + "/" + p, nil + } + } + } + case nil: + return res.Path, nil + default: + return "", err + } + } + return "", err +} + func (c *Conn) Delete(path string, version int32) error { res := &deleteResponse{} return c.request(opDelete, &deleteRequest{path, version}, res) diff --git a/lock.go b/lock.go index 0f03a6a4..dffb3618 100644 --- a/lock.go +++ b/lock.go @@ -1,10 +1,8 @@ package zk import ( - "crypto/rand" "errors" "fmt" - "io" "strconv" "strings" ) @@ -15,18 +13,18 @@ var ( ) type Lock struct { - c *Conn - name string - path string - seq int - guid [16]byte + c *Conn + path string + acl []ACL + lockPath string + seq int } -func NewLock(c *Conn, name string) *Lock { +func NewLock(c *Conn, path string, acl []ACL) *Lock { return &Lock{ c: c, - name: name, - path: "", + path: path, + acl: acl, } } @@ -36,22 +34,36 @@ func parseSeq(path string) (int, error) { } func (l *Lock) Lock() error { - if l.path != "" { + if l.lockPath != "" { return ErrDeadlock } - _, err := io.ReadFull(rand.Reader, l.guid[:16]) - if err != nil { - return err + prefix := fmt.Sprintf("%s/lock-", l.path) + + path := "" + var err error + for i := 0; i < 3; i++ { + path, err = l.c.CreateProtectedEphemeralSequential(prefix, []byte{}, l.acl) + if err == ErrNoNode { + // Create parent node. + parts := strings.Split(l.path, "/") + pth := "" + for _, p := range parts[1:] { + pth += "/" + p + _, err := l.c.Create(pth, []byte{}, 0, l.acl) + if err != nil { + return err + } + } + } else if err == nil { + break + } else { + return err + } } - - basePath := fmt.Sprintf("/locks/%s", l.name) - prefix := fmt.Sprintf("%s/%x-", basePath, l.guid) - path, err := l.c.Create(prefix, []byte("lock"), FlagEphemeral|FlagSequence, WorldACL(PermAll)) if err != nil { return err } - // TODO: handle recoverable errors seq, err := parseSeq(path) if err != nil { @@ -59,7 +71,7 @@ func (l *Lock) Lock() error { } for { - children, _, err := l.c.Children(basePath) + children, _, err := l.c.Children(l.path) if err != nil { return err } @@ -86,7 +98,7 @@ func (l *Lock) Lock() error { break } - exists, _, ch, err := l.c.ExistsW(prevSeqPath) + exists, _, ch, err := l.c.ExistsW(l.path + "/" + prevSeqPath) if err != nil { return err } @@ -100,18 +112,18 @@ func (l *Lock) Lock() error { } l.seq = seq - l.path = path + l.lockPath = path return nil } func (l *Lock) Unlock() error { - if l.path == "" { + if l.lockPath == "" { return ErrNotLocked } - if err := l.c.Delete(l.path, -1); err != nil { + if err := l.c.Delete(l.lockPath, -1); err != nil { return err } - l.path = "" + l.lockPath = "" l.seq = 0 return nil } diff --git a/lock_test.go b/lock_test.go index bee9cb05..4bd35caa 100644 --- a/lock_test.go +++ b/lock_test.go @@ -12,7 +12,9 @@ func TestLock(t *testing.T) { } defer zk.Close() - l := NewLock(zk, "test") + acls := WorldACL(PermAll) + + l := NewLock(zk, "/test", acls) if err := l.Lock(); err != nil { t.Fatal(err) } @@ -26,7 +28,7 @@ func TestLock(t *testing.T) { t.Fatal(err) } - l2 := NewLock(zk, "test") + l2 := NewLock(zk, "/test", acls) go func() { if err := l2.Lock(); err != nil { t.Fatal(err)