Skip to content

Commit

Permalink
Pass message context through the VM interface (ava-labs#2219)
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenButtolph authored Nov 16, 2022
1 parent 084d4d5 commit 5be9266
Show file tree
Hide file tree
Showing 210 changed files with 4,556 additions and 3,812 deletions.
5 changes: 3 additions & 2 deletions api/admin/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,11 @@ type LoadVMsReply struct {
}

// LoadVMs loads any new VMs available to the node and returns the added VMs.
func (service *Admin) LoadVMs(_ *http.Request, _ *struct{}, reply *LoadVMsReply) error {
func (service *Admin) LoadVMs(r *http.Request, _ *struct{}, reply *LoadVMsReply) error {
service.Log.Debug("Admin: LoadVMs called")

loadedVMs, failedVMs, err := service.VMRegistry.ReloadWithReadLock()
ctx := r.Context()
loadedVMs, failedVMs, err := service.VMRegistry.ReloadWithReadLock(ctx)
if err != nil {
return err
}
Expand Down
13 changes: 7 additions & 6 deletions api/admin/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package admin

import (
"errors"
"net/http"
"testing"

"github.com/golang/mock/gomock"
Expand Down Expand Up @@ -69,13 +70,13 @@ func TestLoadVMsSuccess(t *testing.T) {
}

resources.mockLog.EXPECT().Debug(gomock.Any()).Times(1)
resources.mockVMRegistry.EXPECT().ReloadWithReadLock().Times(1).Return(newVMs, failedVMs, nil)
resources.mockVMRegistry.EXPECT().ReloadWithReadLock(gomock.Any()).Times(1).Return(newVMs, failedVMs, nil)
resources.mockVMManager.EXPECT().Aliases(id1).Times(1).Return(alias1, nil)
resources.mockVMManager.EXPECT().Aliases(id2).Times(1).Return(alias2, nil)

// execute test
reply := LoadVMsReply{}
err := resources.admin.LoadVMs(nil, nil, &reply)
err := resources.admin.LoadVMs(&http.Request{}, nil, &reply)

require.Equal(t, expectedVMRegistry, reply.NewVMs)
require.Equal(t, err, nil)
Expand All @@ -88,10 +89,10 @@ func TestLoadVMsReloadFails(t *testing.T) {

resources.mockLog.EXPECT().Debug(gomock.Any()).Times(1)
// Reload fails
resources.mockVMRegistry.EXPECT().ReloadWithReadLock().Times(1).Return(nil, nil, errOops)
resources.mockVMRegistry.EXPECT().ReloadWithReadLock(gomock.Any()).Times(1).Return(nil, nil, errOops)

reply := LoadVMsReply{}
err := resources.admin.LoadVMs(nil, nil, &reply)
err := resources.admin.LoadVMs(&http.Request{}, nil, &reply)

require.Equal(t, err, errOops)
}
Expand All @@ -111,12 +112,12 @@ func TestLoadVMsGetAliasesFails(t *testing.T) {
alias1 := []string{id1.String(), "vm1-alias-1", "vm1-alias-2"}

resources.mockLog.EXPECT().Debug(gomock.Any()).Times(1)
resources.mockVMRegistry.EXPECT().ReloadWithReadLock().Times(1).Return(newVMs, failedVMs, nil)
resources.mockVMRegistry.EXPECT().ReloadWithReadLock(gomock.Any()).Times(1).Return(newVMs, failedVMs, nil)
resources.mockVMManager.EXPECT().Aliases(id1).Times(1).Return(alias1, nil)
resources.mockVMManager.EXPECT().Aliases(id2).Times(1).Return(nil, errOops)

reply := LoadVMsReply{}
err := resources.admin.LoadVMs(nil, nil, &reply)
err := resources.admin.LoadVMs(&http.Request{}, nil, &reply)

require.Equal(t, err, errOops)
}
12 changes: 8 additions & 4 deletions api/health/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,21 @@

package health

import "context"

var _ Checker = CheckerFunc(nil)

// Checker can have its health checked
type Checker interface {
// HealthCheck returns health check results and, if not healthy, a non-nil
// error
//
// It is expected that the results are json marshallable.
HealthCheck() (interface{}, error)
HealthCheck(context.Context) (interface{}, error)
}

type CheckerFunc func() (interface{}, error)
type CheckerFunc func(context.Context) (interface{}, error)

func (f CheckerFunc) HealthCheck() (interface{}, error) {
return f()
func (f CheckerFunc) HealthCheck(ctx context.Context) (interface{}, error) {
return f(ctx)
}
11 changes: 6 additions & 5 deletions api/health/health.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package health

import (
"context"
"time"

"github.com/prometheus/client_golang/prometheus"
Expand All @@ -21,7 +22,7 @@ type Health interface {
Registerer
Reporter

Start(freq time.Duration)
Start(ctx context.Context, freq time.Duration)
Stop()
}

Expand Down Expand Up @@ -108,10 +109,10 @@ func (h *health) Liveness() (map[string]Result, bool) {
return results, healthy
}

func (h *health) Start(freq time.Duration) {
h.readiness.Start(freq)
h.health.Start(freq)
h.liveness.Start(freq)
func (h *health) Start(ctx context.Context, freq time.Duration) {
h.readiness.Start(ctx, freq)
h.health.Start(ctx, freq)
h.liveness.Start(ctx, freq)
}

func (h *health) Stop() {
Expand Down
17 changes: 9 additions & 8 deletions api/health/health_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package health

import (
"context"
"errors"
"fmt"
"sync"
Expand Down Expand Up @@ -56,7 +57,7 @@ func awaitLiveness(r Reporter, liveness bool) {
func TestDuplicatedRegistations(t *testing.T) {
require := require.New(t)

check := CheckerFunc(func() (interface{}, error) {
check := CheckerFunc(func(context.Context) (interface{}, error) {
return "", nil
})

Expand All @@ -82,7 +83,7 @@ func TestDuplicatedRegistations(t *testing.T) {
func TestDefaultFailing(t *testing.T) {
require := require.New(t)

check := CheckerFunc(func() (interface{}, error) {
check := CheckerFunc(func(context.Context) (interface{}, error) {
return "", nil
})

Expand Down Expand Up @@ -126,7 +127,7 @@ func TestDefaultFailing(t *testing.T) {
func TestPassingChecks(t *testing.T) {
require := require.New(t)

check := CheckerFunc(func() (interface{}, error) {
check := CheckerFunc(func(context.Context) (interface{}, error) {
return "", nil
})

Expand All @@ -140,7 +141,7 @@ func TestPassingChecks(t *testing.T) {
err = h.RegisterLivenessCheck("check", check)
require.NoError(err)

h.Start(checkFreq)
h.Start(context.Background(), checkFreq)
defer h.Stop()

{
Expand Down Expand Up @@ -193,7 +194,7 @@ func TestPassingThenFailingChecks(t *testing.T) {
shouldCheckErr utils.AtomicBool
checkErr = errors.New("unhealthy")
)
check := CheckerFunc(func() (interface{}, error) {
check := CheckerFunc(func(context.Context) (interface{}, error) {
if shouldCheckErr.GetValue() {
return checkErr.Error(), checkErr
}
Expand All @@ -210,7 +211,7 @@ func TestPassingThenFailingChecks(t *testing.T) {
err = h.RegisterLivenessCheck("check", check)
require.NoError(err)

h.Start(checkFreq)
h.Start(context.Background(), checkFreq)
defer h.Stop()

awaitReadiness(h)
Expand Down Expand Up @@ -254,14 +255,14 @@ func TestDeadlockRegression(t *testing.T) {
require.NoError(err)

var lock sync.Mutex
check := CheckerFunc(func() (interface{}, error) {
check := CheckerFunc(func(context.Context) (interface{}, error) {
lock.Lock()
time.Sleep(time.Nanosecond)
lock.Unlock()
return "", nil
})

h.Start(time.Nanosecond)
h.Start(context.Background(), time.Nanosecond)
defer h.Stop()

for i := 0; i < 1000; i++ {
Expand Down
5 changes: 3 additions & 2 deletions api/health/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package health

import (
"context"
"testing"

"github.com/prometheus/client_golang/prometheus"
Expand All @@ -16,7 +17,7 @@ import (
func TestServiceResponses(t *testing.T) {
require := require.New(t)

check := CheckerFunc(func() (interface{}, error) {
check := CheckerFunc(func(context.Context) (interface{}, error) {
return "", nil
})

Expand Down Expand Up @@ -68,7 +69,7 @@ func TestServiceResponses(t *testing.T) {
require.False(reply.Healthy)
}

h.Start(checkFreq)
h.Start(context.Background(), checkFreq)
defer h.Stop()

awaitReadiness(h)
Expand Down
19 changes: 10 additions & 9 deletions api/health/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package health

import (
"context"
"errors"
"fmt"
"sync"
Expand Down Expand Up @@ -60,13 +61,13 @@ func (w *worker) RegisterCheck(name string, checker Checker) error {

func (w *worker) RegisterMonotonicCheck(name string, checker Checker) error {
var result utils.AtomicInterface
return w.RegisterCheck(name, CheckerFunc(func() (interface{}, error) {
return w.RegisterCheck(name, CheckerFunc(func(ctx context.Context) (interface{}, error) {
details := result.GetValue()
if details != nil {
return details, nil
}

details, err := checker.HealthCheck()
details, err := checker.HealthCheck(ctx)
if err == nil {
result.SetValue(details)
}
Expand All @@ -87,17 +88,17 @@ func (w *worker) Results() (map[string]Result, bool) {
return results, healthy
}

func (w *worker) Start(freq time.Duration) {
func (w *worker) Start(ctx context.Context, freq time.Duration) {
w.startOnce.Do(func() {
go func() {
ticker := time.NewTicker(freq)
defer ticker.Stop()

w.runChecks()
w.runChecks(ctx)
for {
select {
case <-ticker.C:
w.runChecks()
w.runChecks(ctx)
case <-w.closer:
return
}
Expand All @@ -112,7 +113,7 @@ func (w *worker) Stop() {
})
}

func (w *worker) runChecks() {
func (w *worker) runChecks(ctx context.Context) {
w.checksLock.RLock()
// Copy the [w.checks] map to collect the checks that we will be running
// during this iteration. If [w.checks] is modified during this iteration of
Expand All @@ -127,20 +128,20 @@ func (w *worker) runChecks() {
var wg sync.WaitGroup
wg.Add(len(checks))
for name, check := range checks {
go w.runCheck(&wg, name, check)
go w.runCheck(ctx, &wg, name, check)
}
wg.Wait()
}

func (w *worker) runCheck(wg *sync.WaitGroup, name string, check Checker) {
func (w *worker) runCheck(ctx context.Context, wg *sync.WaitGroup, name string, check Checker) {
defer wg.Done()

start := time.Now()

// To avoid any deadlocks when [RegisterCheck] is called with a lock
// that is grabbed by [check.HealthCheck], we ensure that no locks
// are held when [check.HealthCheck] is called.
details, err := check.HealthCheck()
details, err := check.HealthCheck(ctx)
end := time.Now()

result := Result{
Expand Down
2 changes: 1 addition & 1 deletion api/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ func (s *server) registerChain(chainName string, engine common.Engine) {

ctx := engine.Context()
ctx.Lock.Lock()
handlers, err = engine.GetVM().CreateHandlers()
handlers, err = engine.GetVM().CreateHandlers(context.TODO())
ctx.Lock.Unlock()
if err != nil {
s.log.Error("failed to create handlers",
Expand Down
Loading

0 comments on commit 5be9266

Please sign in to comment.