Skip to content

Commit

Permalink
Check authorization in WatchDocument (yorkie-team#211)
Browse files Browse the repository at this point in the history
If an error occurs before starting the watch, there is no way to deliver
the error when waiting synchronously.

So, watch-started event is added and start the watch asynchronously.

Co-authored-by: Dongcheol Choe <[email protected]>
  • Loading branch information
hackerwins and dc7303 committed Jul 11, 2021
1 parent 7aafb1d commit a9c60b2
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 52 deletions.
17 changes: 2 additions & 15 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ type WatchResponseType string

// The values below are types of WatchResponseType.
const (
WatchStarted WatchResponseType = "watch-started"
DocumentsChanged WatchResponseType = "documents-changed"
PeersChanged WatchResponseType = "peers-changed"
)
Expand Down Expand Up @@ -411,7 +412,7 @@ func (c *Client) Watch(ctx context.Context, docs ...*document.Document) <-chan W
}

return &WatchResponse{
Type: PeersChanged,
Type: WatchStarted,
PeersMapByDoc: c.PeersMapByDoc(),
}, nil
case *api.WatchDocumentsResponse_Event:
Expand Down Expand Up @@ -449,20 +450,6 @@ func (c *Client) Watch(ctx context.Context, docs ...*document.Document) <-chan W
return nil, fmt.Errorf("unsupported response type")
}

// waiting for starting watch
pbResp, err := stream.Recv()
if err != nil {
rch <- WatchResponse{Err: err}
close(rch)
return rch
}
if _, err := handleResponse(pbResp); err != nil {
rch <- WatchResponse{Err: err}
close(rch)
return rch
}

// starting to watch documents
go func() {
for {
pbResp, err := stream.Recv()
Expand Down
2 changes: 2 additions & 0 deletions pkg/types/auth_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ const (
AttachDocument Method = "AttachDocument"
DetachDocument Method = "DetachDocument"
PushPull Method = "PushPull"
WatchDocuments Method = "WatchDocuments"
)

// IsAuthMethod returns whether the given method can be used for authorization.
Expand All @@ -73,6 +74,7 @@ func AuthMethods() []Method {
AttachDocument,
DetachDocument,
PushPull,
WatchDocuments,
}
}

Expand Down
19 changes: 18 additions & 1 deletion test/integration/auth_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"

"github.com/rs/xid"
Expand Down Expand Up @@ -96,7 +97,10 @@ func TestAuthWebhook(t *testing.T) {
server, _ := newAuthServer(t)

conf := helper.TestConfig(server.URL)
conf.Backend.AuthorizationWebhookMethods = []string{string(types.AttachDocument)}
conf.Backend.AuthorizationWebhookMethods = []string{
string(types.AttachDocument),
string(types.WatchDocuments),
}

agent, err := yorkie.New(conf)
assert.NoError(t, err)
Expand All @@ -114,5 +118,18 @@ func TestAuthWebhook(t *testing.T) {
doc := document.New(helper.Collection, t.Name())
err = cli.Attach(ctx, doc)
assert.Equal(t, codes.Unauthenticated, status.Convert(err).Code())

wg := sync.WaitGroup{}

wg.Add(1)
rch := cli.Watch(ctx, doc)
go func() {
defer wg.Done()

resp := <- rch
assert.Equal(t, codes.Unauthenticated, status.Convert(resp.Err).Code())
}()

wg.Wait()
})
}
20 changes: 11 additions & 9 deletions test/integration/cluster_mode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,19 @@ func TestClusterMode(t *testing.T) {
go func() {
defer wg.Done()

select {
case resp := <-rch:
if resp.Err == io.EOF {
for {
select {
case resp := <-rch:
if resp.Err == io.EOF {
return
}
assert.NoError(t, resp.Err)

err := clientA.Sync(ctx, resp.Keys...)
assert.NoError(t, err)
case <-time.After(time.Second):
return
}
assert.NoError(t, resp.Err)

err := clientA.Sync(ctx, resp.Keys...)
assert.NoError(t, err)
case <-time.After(time.Second):
return
}
}()

Expand Down
77 changes: 53 additions & 24 deletions test/integration/document_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"io"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -129,15 +130,21 @@ func TestDocument(t *testing.T) {
go func() {
defer wg.Done()

// receive changed event.
resp := <-rch
if resp.Err == io.EOF {
return
}
assert.NoError(t, resp.Err)
for {
// receive changed event.
resp := <-rch
if resp.Err == io.EOF {
assert.Fail(t, resp.Err.Error())
return
}
assert.NoError(t, resp.Err)

err := c1.Sync(ctx, resp.Keys...)
assert.NoError(t, err)
if resp.Type == client.DocumentsChanged {
err := c1.Sync(ctx, resp.Keys...)
assert.NoError(t, err)
return
}
}
}()

// 02. cli2 updates doc2.
Expand All @@ -155,7 +162,7 @@ func TestDocument(t *testing.T) {
assert.Equal(t, d1.Marshal(), d2.Marshal())
})

t.Run("watch PeersChanged event test", func(t *testing.T) {
t.Run("WatchStarted and PeersChanged event test", func(t *testing.T) {
ctx := context.Background()

d1 := document.New(helper.Collection, t.Name())
Expand All @@ -165,47 +172,69 @@ func TestDocument(t *testing.T) {
assert.NoError(t, c2.Attach(ctx, d2))
defer func() { assert.NoError(t, c2.Detach(ctx, d2)) }()

wg := sync.WaitGroup{}
wgEvents := sync.WaitGroup{}
var types []client.WatchResponseType
wgEvents.Add(1)

// 01. WatchStarted is triggered when starting to watch a document
watch1Ctx, cancel1 := context.WithCancel(ctx)
wrch := c1.Watch(watch1Ctx, d1)
defer cancel1()

wrch := c1.Watch(watch1Ctx, d1)
wgWatchStarted1 := sync.WaitGroup{}
wgWatchStarted1.Add(1)
go func() {
defer wgEvents.Done()
for {
select {
case <-ctx.Done():
assert.Fail(t, "unexpected ctx done")
case <- time.After(time.Second):
assert.Fail(t, "timeout")
return
case wr := <-wrch:
if wr.Err == io.EOF || status.Code(wr.Err) == codes.Canceled {
assert.Fail(t, "unexpected stream closing")
return
}
assert.NoError(t, wr.Err)

if wr.Type == client.PeersChanged {
types = append(types, wr.Type)

if wr.Type == client.WatchStarted {
wgWatchStarted1.Done()
} else if wr.Type == client.PeersChanged {
peers := wr.PeersMapByDoc[d1.Key().BSONKey()]
if len(peers) == 2 {
assert.Equal(t, c2.Metadata(), peers[c2.ID().String()])
wg.Done()
} else if len(peers) == 1 {
assert.Equal(t, c1.Metadata(), peers[c1.ID().String()])
wg.Done()
return
}
}
}
}
}()
wgWatchStarted1.Wait()

// 01. PeersChanged is triggered as a new client watches the document
// 02. PeersChanged is triggered when another client watches the document
watch2Ctx, cancel2 := context.WithCancel(ctx)
wg.Add(1)
_ = c2.Watch(watch2Ctx, d2)
wrch2 := c2.Watch(watch2Ctx, d2)
wgWatchStarted2 := sync.WaitGroup{}
wgWatchStarted2.Add(1)
go func() {
wr := <-wrch2
if wr.Type == client.WatchStarted {
wgWatchStarted2.Done()
return
}
}()
wgWatchStarted2.Wait()

// 02. PeersChanged is triggered because the client closes the watch
wg.Add(1)
// 03. PeersChanged is triggered when another client closes the watch
cancel2()

wg.Wait()
wgEvents.Wait()
assert.Equal(t, []client.WatchResponseType{
client.WatchStarted,
client.PeersChanged,
client.PeersChanged,
}, types)
})
}
12 changes: 11 additions & 1 deletion yorkie/rpc/interceptors/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package interceptors
import (
"context"

grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
Expand Down Expand Up @@ -67,7 +68,16 @@ func (i *AuthInterceptor) Stream() grpc.StreamServerInterceptor {
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
// TODO(hackerwins): extract token and store it on the context.
if i.needAuth() {
token, err := i.extractToken(ss.Context())
if err != nil {
return err
}
wrapped := grpcmiddleware.WrapServerStream(ss)
wrapped.WrappedContext = auth.CtxWithToken(ss.Context(), token)
return handler(srv, wrapped)
}

return handler(srv, ss)
}
}
Expand Down
4 changes: 2 additions & 2 deletions yorkie/rpc/interceptors/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (i *DefaultInterceptor) Unary() grpc.UnaryServerInterceptor {
start := gotime.Now()
resp, err := handler(ctx, req)
if err != nil {
log.Logger.Errorf("RPC : %q %s: %q => %q", info.FullMethod, gotime.Since(start), req, err)
log.Logger.Warnf("RPC : %q %s: %q => %q", info.FullMethod, gotime.Since(start), req, err)
return nil, toStatusError(err)
}

Expand All @@ -74,7 +74,7 @@ func (i *DefaultInterceptor) Stream() grpc.StreamServerInterceptor {
start := gotime.Now()
err := handler(srv, ss)
if err != nil {
log.Logger.Infof("RPC : stream %q %s => %q", info.FullMethod, gotime.Since(start), err.Error())
log.Logger.Warnf("RPC : stream %q %s => %q", info.FullMethod, gotime.Since(start), err.Error())
return toStatusError(err)
}

Expand Down
14 changes: 14 additions & 0 deletions yorkie/rpc/yorkie_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,20 @@ func (s *yorkieServer) WatchDocuments(
}
docKeys := converter.FromDocumentKeys(req.DocumentKeys)

var attrs []types.AccessAttribute
for _, k := range docKeys {
attrs = append(attrs, types.AccessAttribute{
Key: k.BSONKey(),
Verb: types.Read,
})
}
if err := auth.VerifyAccess(stream.Context(), s.backend, &types.AccessInfo{
Method: types.WatchDocuments,
Attributes: attrs,
}); err != nil {
return err
}

subscription, peersMap, err := s.watchDocs(
stream.Context(),
*client,
Expand Down

0 comments on commit a9c60b2

Please sign in to comment.