Skip to content

Commit

Permalink
internal/gps: resolve race in monitoredCmd
Browse files Browse the repository at this point in the history
*monitoredCmd.combinedOutput and gps.TestMonitoredCmd used to access
monitoredCmd.stdout and monitoredCmd.stdout (both of type
*activityBuffer) without obtaining a lock on activityBuffer.Mutex.

This caused the tests to fail on rare occasions. This change add 2
methods to ensure reads from to the underlying bytes.Buffer happens only
when a lock is obtained.

Signed-off-by: Ibrahim AshShohail <[email protected]>
  • Loading branch information
ibrasho committed Jul 22, 2017
1 parent 3781a6f commit 24a17e0
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 39 deletions.
23 changes: 20 additions & 3 deletions internal/gps/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,11 @@ func (c *monitoredCmd) hasTimedOut() bool {

func (c *monitoredCmd) combinedOutput(ctx context.Context) ([]byte, error) {
if err := c.run(ctx); err != nil {
return c.stderr.buf.Bytes(), err
return c.stderr.Bytes(), err
}

// FIXME(sdboyer) this is not actually combined output
return c.stdout.buf.Bytes(), nil
return c.stdout.Bytes(), nil
}

// activityBuffer is a buffer that keeps track of the last time a Write
Expand All @@ -150,14 +150,31 @@ func newActivityBuffer() *activityBuffer {

func (b *activityBuffer) Write(p []byte) (int, error) {
b.Lock()
b.lastActivityStamp = time.Now()
defer b.Unlock()

b.lastActivityStamp = time.Now()

return b.buf.Write(p)
}

func (b *activityBuffer) String() string {
b.Lock()
defer b.Unlock()

return b.buf.String()
}

func (b *activityBuffer) Bytes() []byte {
b.Lock()
defer b.Unlock()

return b.buf.Bytes()
}

func (b *activityBuffer) lastActivity() time.Time {
b.Lock()
defer b.Unlock()

return b.lastActivityStamp
}

Expand Down
83 changes: 47 additions & 36 deletions internal/gps/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,49 +32,60 @@ func TestMonitoredCmd(t *testing.T) {
}
defer os.Remove("./echosleep")

cmd := mkTestCmd(2)
err = cmd.run(context.Background())
if err != nil {
t.Errorf("Expected command not to fail: %s", err)
tests := []struct {
name string
iterations int
output string
err bool
timeout bool
}{
{"success", 2, "foo\nfoo\n", false, false},
{"timeout", 5, "foo\nfoo\nfoo\nfoo\n", true, true},
}

expectedOutput := "foo\nfoo\n"
if cmd.stdout.buf.String() != expectedOutput {
t.Errorf("Unexpected output:\n\t(GOT): %s\n\t(WNT): %s", cmd.stdout.buf.String(), expectedOutput)
}
for _, want := range tests {
t.Run(want.name, func(t *testing.T) {
cmd := mkTestCmd(want.iterations)

cmd2 := mkTestCmd(10)
err = cmd2.run(context.Background())
if err == nil {
t.Error("Expected command to fail")
}
err := cmd.run(context.Background())
if !want.err && err != nil {
t.Errorf("Eexpected command not to fail, got error: %s", err)
} else if want.err && err == nil {
t.Error("expected command to fail")
}

_, ok := err.(*noProgressError)
if !ok {
t.Errorf("Expected a timeout error, but got: %s", err)
}
got := cmd.stdout.String()
if want.output != got {
t.Errorf("unexpected output:\n\t(GOT):\n%s\n\t(WNT):\n%s", got, want.output)
}

expectedOutput = "foo\nfoo\nfoo\nfoo\n"
if cmd2.stdout.buf.String() != expectedOutput {
t.Errorf("Unexpected output:\n\t(GOT): %s\n\t(WNT): %s", cmd2.stdout.buf.String(), expectedOutput)
if want.timeout {
_, ok := err.(*noProgressError)
if !ok {
t.Errorf("Expected a timeout error, but got: %s", err)
}
}
})
}

ctx, cancel := context.WithCancel(context.Background())
sync1, errchan := make(chan struct{}), make(chan error)
cmd3 := mkTestCmd(2)
go func() {
close(sync1)
errchan <- cmd3.run(ctx)
}()
t.Run("cancel", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
sync, errchan := make(chan struct{}), make(chan error)
cmd := mkTestCmd(2)
go func() {
close(sync)
errchan <- cmd.run(ctx)
}()

// Make sure goroutine is at least started before we cancel the context.
<-sync1
// Give it a bit to get the process started.
<-time.After(5 * time.Millisecond)
cancel()
// Make sure goroutine is at least started before we cancel the context.
<-sync
// Give it a bit to get the process started.
<-time.After(5 * time.Millisecond)
cancel()

err = <-errchan
if err != context.Canceled {
t.Errorf("should have gotten canceled error, got %s", err)
}
err := <-errchan
if err != context.Canceled {
t.Errorf("expected a canceled error, got %s", err)
}
})
}

0 comments on commit 24a17e0

Please sign in to comment.