Skip to content

Commit

Permalink
Merge pull request moby#33859 from tonistiigi/session-includepaths
Browse files Browse the repository at this point in the history
Add path filtering to build session client
  • Loading branch information
thaJeztah authored Jul 6, 2017
2 parents db8c265 + 4141d8f commit 19ee873
Show file tree
Hide file tree
Showing 8 changed files with 187 additions and 22 deletions.
3 changes: 1 addition & 2 deletions builder/dockerfile/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@ func (bm *BuildManager) initializeClientSession(ctx context.Context, cancel func
}()
if options.RemoteContext == remotecontext.ClientSessionRemote {
st := time.Now()
csi, err := NewClientSessionSourceIdentifier(ctx, bm.sg,
options.SessionID, []string{"/"})
csi, err := NewClientSessionSourceIdentifier(ctx, bm.sg, options.SessionID)
if err != nil {
return nil, err
}
Expand Down
19 changes: 9 additions & 10 deletions builder/dockerfile/clientsession.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,25 @@ func (cst *ClientSessionTransport) Copy(ctx context.Context, id fscache.RemoteId
}

return filesync.FSSync(ctx, csi.caller, filesync.FSSendRequestOpt{
SrcPaths: csi.srcPaths,
DestDir: dest,
CacheUpdater: cu,
IncludePatterns: csi.includePatterns,
DestDir: dest,
CacheUpdater: cu,
})
}

// ClientSessionSourceIdentifier is an identifier that can be used for requesting
// files from remote client
type ClientSessionSourceIdentifier struct {
srcPaths []string
caller session.Caller
sharedKey string
uuid string
includePatterns []string
caller session.Caller
sharedKey string
uuid string
}

// NewClientSessionSourceIdentifier returns new ClientSessionSourceIdentifier instance
func NewClientSessionSourceIdentifier(ctx context.Context, sg SessionGetter, uuid string, sources []string) (*ClientSessionSourceIdentifier, error) {
func NewClientSessionSourceIdentifier(ctx context.Context, sg SessionGetter, uuid string) (*ClientSessionSourceIdentifier, error) {
csi := &ClientSessionSourceIdentifier{
uuid: uuid,
srcPaths: sources,
uuid: uuid,
}
caller, err := sg.Get(ctx, uuid)
if err != nil {
Expand Down
3 changes: 2 additions & 1 deletion client/session/filesync/diffcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ import (
"github.com/tonistiigi/fsutil"
)

func sendDiffCopy(stream grpc.Stream, dir string, excludes []string, progress progressCb) error {
func sendDiffCopy(stream grpc.Stream, dir string, includes, excludes []string, progress progressCb) error {
return fsutil.Send(stream.Context(), stream, dir, &fsutil.WalkOpt{
ExcludePatterns: excludes,
IncludePaths: includes, // TODO: rename IncludePatterns
}, progress)
}

Expand Down
20 changes: 15 additions & 5 deletions client/session/filesync/filesync.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ import (
"google.golang.org/grpc/metadata"
)

const (
keyOverrideExcludes = "override-excludes"
keyIncludePatterns = "include-patterns"
)

type fsSyncProvider struct {
root string
excludes []string
Expand Down Expand Up @@ -54,9 +59,10 @@ func (sp *fsSyncProvider) handle(method string, stream grpc.ServerStream) error
opts, _ := metadata.FromContext(stream.Context()) // if no metadata continue with empty object

var excludes []string
if len(opts["Override-Excludes"]) == 0 || opts["Override-Excludes"][0] != "true" {
if len(opts[keyOverrideExcludes]) == 0 || opts[keyOverrideExcludes][0] != "true" {
excludes = sp.excludes
}
includes := opts[keyIncludePatterns]

var progress progressCb
if sp.p != nil {
Expand All @@ -69,7 +75,7 @@ func (sp *fsSyncProvider) handle(method string, stream grpc.ServerStream) error
doneCh = sp.doneCh
sp.doneCh = nil
}
err := pr.sendFn(stream, sp.root, excludes, progress)
err := pr.sendFn(stream, sp.root, includes, excludes, progress)
if doneCh != nil {
if err != nil {
doneCh <- err
Expand All @@ -88,7 +94,7 @@ type progressCb func(int, bool)

type protocol struct {
name string
sendFn func(stream grpc.Stream, srcDir string, excludes []string, progress progressCb) error
sendFn func(stream grpc.Stream, srcDir string, includes, excludes []string, progress progressCb) error
recvFn func(stream grpc.Stream, destDir string, cu CacheUpdater) error
}

Expand All @@ -115,7 +121,7 @@ var supportedProtocols = []protocol{

// FSSendRequestOpt defines options for FSSend request
type FSSendRequestOpt struct {
SrcPaths []string
IncludePatterns []string
OverrideExcludes bool
DestDir string
CacheUpdater CacheUpdater
Expand All @@ -142,7 +148,11 @@ func FSSync(ctx context.Context, c session.Caller, opt FSSendRequestOpt) error {

opts := make(map[string][]string)
if opt.OverrideExcludes {
opts["Override-Excludes"] = []string{"true"}
opts[keyOverrideExcludes] = []string{"true"}
}

if opt.IncludePatterns != nil {
opts[keyIncludePatterns] = opt.IncludePatterns
}

ctx, cancel := context.WithCancel(ctx)
Expand Down
71 changes: 71 additions & 0 deletions client/session/filesync/filesync_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package filesync

import (
"context"
"io/ioutil"
"path/filepath"
"testing"

"github.com/docker/docker/client/session"
"github.com/docker/docker/client/session/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)

func TestFileSyncIncludePatterns(t *testing.T) {
tmpDir, err := ioutil.TempDir("", "fsynctest")
require.NoError(t, err)

destDir, err := ioutil.TempDir("", "fsynctest")
require.NoError(t, err)

err = ioutil.WriteFile(filepath.Join(tmpDir, "foo"), []byte("content1"), 0600)
require.NoError(t, err)

err = ioutil.WriteFile(filepath.Join(tmpDir, "bar"), []byte("content2"), 0600)
require.NoError(t, err)

s, err := session.NewSession("foo", "bar")
require.NoError(t, err)

m, err := session.NewManager()
require.NoError(t, err)

fs := NewFSSyncProvider(tmpDir, nil)
s.Allow(fs)

dialer := session.Dialer(testutil.TestStream(testutil.Handler(m.HandleConn)))

g, ctx := errgroup.WithContext(context.Background())

g.Go(func() error {
return s.Run(ctx, dialer)
})

g.Go(func() (reterr error) {
c, err := m.Get(ctx, s.UUID())
if err != nil {
return err
}
if err := FSSync(ctx, c, FSSendRequestOpt{
DestDir: destDir,
IncludePatterns: []string{"ba*"},
}); err != nil {
return err
}

_, err = ioutil.ReadFile(filepath.Join(destDir, "foo"))
assert.Error(t, err)

dt, err := ioutil.ReadFile(filepath.Join(destDir, "bar"))
if err != nil {
return err
}
assert.Equal(t, "content2", string(dt))
return s.Close()
})

err = g.Wait()
require.NoError(t, err)
}
2 changes: 1 addition & 1 deletion client/session/filesync/tarstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"google.golang.org/grpc"
)

func sendTarStream(stream grpc.Stream, dir string, excludes []string, progress progressCb) error {
func sendTarStream(stream grpc.Stream, dir string, includes, excludes []string, progress progressCb) error {
a, err := archive.TarWithOptions(dir, &archive.TarOptions{
ExcludePatterns: excludes,
})
Expand Down
21 changes: 18 additions & 3 deletions client/session/manager.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package session

import (
"net"
"net/http"
"strings"
"sync"
Expand Down Expand Up @@ -49,8 +50,6 @@ func (sm *Manager) HandleHTTPRequest(ctx context.Context, w http.ResponseWriter,
}

uuid := r.Header.Get(headerSessionUUID)
name := r.Header.Get(headerSessionName)
sharedKey := r.Header.Get(headerSessionSharedKey)

proto := r.Header.Get("Upgrade")

Expand Down Expand Up @@ -89,9 +88,25 @@ func (sm *Manager) HandleHTTPRequest(ctx context.Context, w http.ResponseWriter,
conn.Write([]byte{})
resp.Write(conn)

return sm.handleConn(ctx, conn, r.Header)
}

// HandleConn handles an incoming raw connection
func (sm *Manager) HandleConn(ctx context.Context, conn net.Conn, opts map[string][]string) error {
sm.mu.Lock()
return sm.handleConn(ctx, conn, opts)
}

// caller needs to take lock, this function will release it
func (sm *Manager) handleConn(ctx context.Context, conn net.Conn, opts map[string][]string) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

h := http.Header(opts)
uuid := h.Get(headerSessionUUID)
name := h.Get(headerSessionName)
sharedKey := h.Get(headerSessionSharedKey)

ctx, cc, err := grpcClientConn(ctx, conn)
if err != nil {
sm.mu.Unlock()
Expand All @@ -111,7 +126,7 @@ func (sm *Manager) HandleHTTPRequest(ctx context.Context, w http.ResponseWriter,
supported: make(map[string]struct{}),
}

for _, m := range r.Header[headerSessionMethod] {
for _, m := range opts[headerSessionMethod] {
c.supported[strings.ToLower(m)] = struct{}{}
}
sm.sessions[uuid] = c
Expand Down
70 changes: 70 additions & 0 deletions client/session/testutil/testutil.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package testutil

import (
"io"
"net"
"time"

"github.com/Sirupsen/logrus"
"golang.org/x/net/context"
)

// Handler is function called to handle incoming connection
type Handler func(ctx context.Context, conn net.Conn, meta map[string][]string) error

// Dialer is a function for dialing an outgoing connection
type Dialer func(ctx context.Context, proto string, meta map[string][]string) (net.Conn, error)

// TestStream creates an in memory session dialer for a handler function
func TestStream(handler Handler) Dialer {
s1, s2 := sockPair()
return func(ctx context.Context, proto string, meta map[string][]string) (net.Conn, error) {
go func() {
err := handler(context.TODO(), s1, meta)
if err != nil {
logrus.Error(err)
}
s1.Close()
}()
return s2, nil
}
}

func sockPair() (*sock, *sock) {
pr1, pw1 := io.Pipe()
pr2, pw2 := io.Pipe()
return &sock{pw1, pr2, pw1}, &sock{pw2, pr1, pw2}
}

type sock struct {
io.Writer
io.Reader
io.Closer
}

func (s *sock) LocalAddr() net.Addr {
return dummyAddr{}
}
func (s *sock) RemoteAddr() net.Addr {
return dummyAddr{}
}
func (s *sock) SetDeadline(t time.Time) error {
return nil
}
func (s *sock) SetReadDeadline(t time.Time) error {
return nil
}
func (s *sock) SetWriteDeadline(t time.Time) error {
return nil
}

type dummyAddr struct {
}

func (d dummyAddr) Network() string {
return "tcp"
}

func (d dummyAddr) String() string {
return "localhost"
}

0 comments on commit 19ee873

Please sign in to comment.