Skip to content

Commit

Permalink
Add auth support
Browse files Browse the repository at this point in the history
  • Loading branch information
samuel committed May 30, 2013
1 parent 90bc964 commit 477e2f8
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 1 deletion.
29 changes: 28 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,10 @@ func (c *Conn) request(opcode int32, req interface{}, res interface{}) error {
return <-c.queueRequest(opcode, req, res)
}

func (c *Conn) AddAuth(scheme string, auth []byte) error {
return c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{})
}

func (c *Conn) Children(path string) ([]string, *Stat, error) {
res := &getChildren2Response{}
err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: false}, res)
Expand All @@ -592,7 +596,7 @@ func (c *Conn) Get(path string) ([]byte, *Stat, error) {

func (c *Conn) GetW(path string) ([]byte, *Stat, <-chan Event, error) {
res := &getDataResponse{}
err := c.request(opExists, &getDataRequest{Path: path, Watch: true}, res)
err := c.request(opGetData, &getDataRequest{Path: path, Watch: true}, res)
var ech chan Event
if err == nil {
ech = c.addWatcher(path, watcherTypeData)
Expand Down Expand Up @@ -659,6 +663,17 @@ func (c *Conn) Delete(path string, version int32) error {
return c.request(opDelete, &deleteRequest{path, version}, res)
}

func (c *Conn) Exists(path string) (bool, *Stat, error) {
res := &existsResponse{}
err := c.request(opExists, &existsRequest{Path: path, Watch: false}, res)
exists := true
if err == ErrNoNode {
exists = false
err = nil
}
return exists, &res.Stat, err
}

func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) {
res := &existsResponse{}
err := c.request(opExists, &existsRequest{Path: path, Watch: true}, res)
Expand All @@ -677,3 +692,15 @@ func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) {
}
return exists, &res.Stat, ech, err
}

func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) {
res := &getAclResponse{}
err := c.request(opGetAcl, &getAclRequest{Path: path}, res)
return res.Acl, &res.Stat, err
}

func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) {
res := &setAclResponse{}
err := c.request(opSetAcl, &setAclRequest{Path: path, Acl: acl, Version: version}, res)
return &res.Stat, err
}
3 changes: 3 additions & 0 deletions structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ type setWatchesResponse struct{}
type syncRequest pathRequest
type syncResponse pathResponse

type setAuthRequest auth
type setAuthResponse struct{}

type watcherEvent struct {
Type EventType
State State
Expand Down
4 changes: 4 additions & 0 deletions tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ func trace(conn1, conn2 net.Conn, client bool) {
cr = &setWatchesRequest{}
case opSync:
cr = &syncRequest{}
case opSetAuth:
cr = &setAuthRequest{}
}
}
} else {
Expand Down Expand Up @@ -129,6 +131,8 @@ func trace(conn1, conn2 net.Conn, client bool) {
cr = &syncResponse{}
case opWatcherEvent:
cr = &watcherEvent{}
case opSetAuth:
cr = &setAuthResponse{}
}
if errnum != 0 {
cr = &struct{}{}
Expand Down
16 changes: 16 additions & 0 deletions util.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package zk

import (
"crypto/sha1"
"encoding/base64"
"fmt"
)

// AuthACL produces an ACL list containing a single ACL which uses the
// provided permissions, with the scheme "auth", and ID "", which is used
// by ZooKeeper to represent any authenticated user.
Expand All @@ -13,3 +19,13 @@ func AuthACL(perms int32) []ACL {
func WorldACL(perms int32) []ACL {
return []ACL{{perms, "world", "anyone"}}
}

func DigestACL(perms int32, user, password string) []ACL {
userPass := []byte(fmt.Sprintf("%s:%s", user, password))
h := sha1.New()
if n, err := h.Write(userPass); err != nil || n != len(userPass) {
panic("SHA1 failed")
}
digest := base64.StdEncoding.EncodeToString(h.Sum(nil))
return []ACL{{perms, "digest", fmt.Sprintf("%s:%s", user, digest)}}
}
94 changes: 94 additions & 0 deletions zk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,100 @@ func TestCreate(t *testing.T) {
}
}

func TestGetSetACL(t *testing.T) {
zk, _, err := Connect([]string{"127.0.0.1:2182"}, time.Second*15)
if err != nil {
t.Fatalf("Connect returned error: %+v", err)
}
defer zk.Close()

if err := zk.AddAuth("digest", []byte("blah")); err != nil {
t.Fatalf("AddAuth returned error %+v", err)
}

path := "/gozk-test"

if err := zk.Delete(path, -1); err != nil && err != ErrNoNode {
t.Fatalf("Delete returned error: %+v", err)
}
if path, err := zk.Create(path, []byte{1, 2, 3, 4}, 0, WorldACL(PermAll)); err != nil {
t.Fatalf("Create returned error: %+v", err)
} else if path != "/gozk-test" {
t.Fatalf("Create returned different path '%s' != '/gozk-test'", path)
}

expected := WorldACL(PermAll)

if acl, stat, err := zk.GetACL(path); err != nil {
t.Fatalf("GetACL returned error %+v", err)
} else if stat == nil {
t.Fatalf("GetACL returned nil Stat")
} else if len(acl) != 1 || expected[0] != acl[0] {
t.Fatalf("GetACL mismatch expected %+v instead of %+v", expected, acl)
}

expected = []ACL{{PermAll, "ip", "127.0.0.1"}}

if stat, err := zk.SetACL(path, expected, -1); err != nil {
t.Fatalf("SetACL returned error %+v", err)
} else if stat == nil {
t.Fatalf("SetACL returned nil Stat")
}

if acl, stat, err := zk.GetACL(path); err != nil {
t.Fatalf("GetACL returned error %+v", err)
} else if stat == nil {
t.Fatalf("GetACL returned nil Stat")
} else if len(acl) != 1 || expected[0] != acl[0] {
t.Fatalf("GetACL mismatch expected %+v instead of %+v", expected, acl)
}
}

func TestAuth(t *testing.T) {
zk, _, err := Connect([]string{"127.0.0.1:2182"}, time.Second*15)
if err != nil {
t.Fatalf("Connect returned error: %+v", err)
}
defer zk.Close()

path := "/gozk-digest-test"
if err := zk.Delete(path, -1); err != nil && err != ErrNoNode {
t.Fatalf("Delete returned error: %+v", err)
}

acl := DigestACL(PermAll, "user", "password")

if p, err := zk.Create(path, []byte{1, 2, 3, 4}, 0, acl); err != nil {
t.Fatalf("Create returned error: %+v", err)
} else if p != path {
t.Fatalf("Create returned different path '%s' != '%s'", p, path)
}

if a, stat, err := zk.GetACL(path); err != nil {
t.Fatalf("GetACL returned error %+v", err)
} else if stat == nil {
t.Fatalf("GetACL returned nil Stat")
} else if len(a) != 1 || acl[0] != a[0] {
t.Fatalf("GetACL mismatch expected %+v instead of %+v", acl, a)
}

if _, _, err := zk.Get(path); err != ErrNoAuth {
t.Fatalf("Get returned error %+v instead of ErrNoAuth", err)
}

if err := zk.AddAuth("digest", []byte("user:password")); err != nil {
t.Fatalf("AddAuth returned error %+v", err)
}

if data, stat, err := zk.Get(path); err != nil {
t.Fatalf("Get returned error %+v", err)
} else if stat == nil {
t.Fatalf("Get returned nil Stat")
} else if len(data) != 4 {
t.Fatalf("Get returned wrong data length")
}
}

func TestChildWatch(t *testing.T) {
zk, _, err := Connect([]string{"127.0.0.1:2182"}, time.Second*15)
if err != nil {
Expand Down

0 comments on commit 477e2f8

Please sign in to comment.