Skip to content

Commit

Permalink
simplified API for LoadAll
Browse files Browse the repository at this point in the history
  • Loading branch information
vikstrous committed Aug 6, 2023
1 parent 857dcef commit 27a2990
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 32 deletions.
36 changes: 18 additions & 18 deletions dataloaden_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,13 @@ func TestUserLoader(t *testing.T) {
if u[3].Name != "user U4" {
t.Fatal("not equal")
}
if err[1] == nil {
if err == nil {
t.Fatal("error expected")
}
if err.(dataloadgen.ErrorSlice)[1] == nil {
t.Fatal("error expected")
}
if err[2] == nil {
if err.(dataloadgen.ErrorSlice)[2] == nil {
t.Fatal("error expected")
}
})
Expand Down Expand Up @@ -137,8 +140,8 @@ func TestUserLoader(t *testing.T) {
t.Run("load many users", func(t *testing.T) {
t.Parallel()
u, err := dl.LoadAll(ctx, []string{"U2", "U4"})
if len(err) != 0 {
t.Fatal("wrong length", err)
if err != nil {
t.Fatal(err)
}
if u[0].Name != "user U2" {
t.Fatal("not equal")
Expand Down Expand Up @@ -179,7 +182,7 @@ func TestUserLoader(t *testing.T) {
if u[1].ID != "U4" {
t.Fatal("not equal")
}
if err[2] == nil {
if err.(dataloadgen.ErrorSlice)[2] == nil {
t.Fatal("error expected")
}
if u[3].ID != "U9" {
Expand Down Expand Up @@ -274,11 +277,8 @@ func TestUserLoader(t *testing.T) {
t.Fatal("wrong length", fetches)
}

if err1[0] != nil {
t.Fatal(err1[0])
}
if err1[1] != nil {
t.Fatal(err1[1])
if err1 != nil {
t.Fatal(err1)
}
if "user U5" != users1[0].Name {
t.Fatal("not equal")
Expand All @@ -293,10 +293,10 @@ func TestUserLoader(t *testing.T) {
t.Fatal("wrong length", fetches)
}

if err2[0] != nil {
t.Fatal(err2[0])
if err2.(dataloadgen.ErrorSlice)[0] != nil {
t.Fatal(err2.(dataloadgen.ErrorSlice)[0])
}
if err2[1] == nil {
if err2.(dataloadgen.ErrorSlice)[1] == nil {
t.Fatal("error expected")
}
if "user U6" != users2[0].Name {
Expand Down Expand Up @@ -332,19 +332,19 @@ func TestUserLoader(t *testing.T) {
t.Fatal("not empty", user)
}
}
if len(errs) != 2 {
if len(errs.(dataloadgen.ErrorSlice)) != 2 {
t.Fatal("wrong length", errs)
}
if errs[0] == nil {
if errs.(dataloadgen.ErrorSlice)[0] == nil {
t.Fatal("error expected")
}
if "failed all fetches" != errs[0].Error() {
if "failed all fetches" != errs.(dataloadgen.ErrorSlice)[0].Error() {
t.Fatal("not equal")
}
if errs[1] == nil {
if errs.(dataloadgen.ErrorSlice)[1] == nil {
t.Fatal("error expected")
}
if "failed all fetches" != errs[1].Error() {
if "failed all fetches" != errs.(dataloadgen.ErrorSlice)[1].Error() {
t.Fatal("not equal")
}
})
Expand Down
20 changes: 12 additions & 8 deletions dataloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

// copied and adapted from github.com/graph-gophers/dataloader
func BenchmarkLoaderFromDataloader(b *testing.B) {
var a = &Avg{}
a := &Avg{}
ctx := context.Background()
dl := dataloadgen.NewLoader(func(keys []string) (results []string, errs []error) {
a.Add(len(keys))
Expand Down Expand Up @@ -119,7 +119,7 @@ func TestLoader(t *testing.T) {
t.Parallel()
errorLoader, _ := ErrorLoader(0)
_, err := errorLoader.LoadAll(ctx, []string{"1", "2", "3"})
if len(err) != 3 {
if len(err.(dataloadgen.ErrorSlice)) != 3 {
t.Error("LoadAll didn't return right number of errors")
}
})
Expand All @@ -128,13 +128,13 @@ func TestLoader(t *testing.T) {
t.Parallel()
loader, _ := OneErrorLoader(3)
_, errs := loader.LoadAll(ctx, []string{"1", "2", "3"})
if len(errs) != 3 {
if len(errs.(dataloadgen.ErrorSlice)) != 3 {
t.Errorf("LoadAll didn't return right number of errors (should match size of input)")
}

var errCount int = 0
var nilCount int = 0
for _, err := range errs {
for _, err := range errs.(dataloadgen.ErrorSlice) {
if err == nil {
nilCount++
} else {
Expand Down Expand Up @@ -177,8 +177,8 @@ func TestLoader(t *testing.T) {
}
}()
panicLoader, _ := PanicLoader(0)
_, errs := panicLoader.LoadAll(ctx, []string{"1"})
if len(errs) < 1 || errs[0].Error() != "Panic received in batch function: Programming error" {
_, err := panicLoader.LoadAll(ctx, []string{"1"})
if err == nil || err.Error() != "Panic received in batch function: Programming error" {
t.Error("Panic was not propagated as an error.")
}
})
Expand Down Expand Up @@ -508,9 +508,10 @@ func BatchOnlyLoader(max int) (*dataloadgen.Loader[string, string], *[][]string)
results = append(results, key)
}
return results, nil
}, dataloadgen.WithBatchCapacity(max)) //dataloadgen.WithClearCacheOnBatch())
}, dataloadgen.WithBatchCapacity(max)) // dataloadgen.WithClearCacheOnBatch())
return identityLoader, &loadCalls
}

func ErrorLoader(max int) (*dataloadgen.Loader[string, string], *[][]string) {
var mu sync.Mutex
var loadCalls [][]string
Expand All @@ -526,6 +527,7 @@ func ErrorLoader(max int) (*dataloadgen.Loader[string, string], *[][]string) {
}, dataloadgen.WithBatchCapacity(max))
return identityLoader, &loadCalls
}

func OneErrorLoader(max int) (*dataloadgen.Loader[string, string], *[][]string) {
var mu sync.Mutex
var loadCalls [][]string
Expand All @@ -547,13 +549,15 @@ func OneErrorLoader(max int) (*dataloadgen.Loader[string, string], *[][]string)
}, dataloadgen.WithBatchCapacity(max))
return identityLoader, &loadCalls
}

func PanicLoader(max int) (*dataloadgen.Loader[string, string], *[][]string) {
var loadCalls [][]string
panicLoader := dataloadgen.NewLoader(func(keys []string) (results []string, errs []error) {
panic("Programming error")
}, dataloadgen.WithBatchCapacity(max)) //, withSilentLogger())
return panicLoader, &loadCalls
}

func BadLoader(max int) (*dataloadgen.Loader[string, string], *[][]string) {
var mu sync.Mutex
var loadCalls [][]string
Expand All @@ -570,7 +574,7 @@ func BadLoader(max int) (*dataloadgen.Loader[string, string], *[][]string) {
func NoCacheLoader(max int) (*dataloadgen.Loader[string, string], *[][]string) {
var mu sync.Mutex
var loadCalls [][]string
//cache := &NoCache{}
// cache := &NoCache{}
identityLoader := dataloadgen.NewLoader(func(keys []string) (results []string, errs []error) {
mu.Lock()
loadCalls = append(loadCalls, keys)
Expand Down
32 changes: 26 additions & 6 deletions dataloadgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dataloadgen

import (
"context"
"errors"
"fmt"
"sync"
"time"
Expand Down Expand Up @@ -37,7 +38,7 @@ func WithTracer(tracer trace.Tracer) Option {
func NewLoader[KeyT comparable, ValueT any](fetch func(keys []KeyT) ([]ValueT, []error), options ...Option) *Loader[KeyT, ValueT] {
config := &loaderConfig{
wait: 16 * time.Millisecond,
maxBatch: 0, //unlimited
maxBatch: 0, // unlimited
}
for _, o := range options {
o(config)
Expand Down Expand Up @@ -148,9 +149,21 @@ func (l *Loader[KeyT, ValueT]) LoadThunk(ctx context.Context, key KeyT) func() (
return thunk
}

// ErrorSlice represents a list of errors that contains at least one error
type ErrorSlice []error

// Error implements the error interface
func (e ErrorSlice) Error() string {
combinedErr := errors.Join([]error(e)...)
if combinedErr == nil {
return "no error data"
}
return combinedErr.Error()
}

// LoadAll fetches many keys at once. It will be broken into appropriate sized
// sub batches depending on how the loader is configured
func (l *Loader[KeyT, ValueT]) LoadAll(ctx context.Context, keys []KeyT) ([]ValueT, []error) {
func (l *Loader[KeyT, ValueT]) LoadAll(ctx context.Context, keys []KeyT) ([]ValueT, error) {
thunks := make([]func() (ValueT, error), len(keys))

for i, key := range keys {
Expand All @@ -169,24 +182,31 @@ func (l *Loader[KeyT, ValueT]) LoadAll(ctx context.Context, keys []KeyT) ([]Valu
if allNil {
return values, nil
}
return values, errors
return values, ErrorSlice(errors)
}

// LoadAllThunk returns a function that when called will block waiting for a ValueT.
// This method should be used if you want one goroutine to make requests to many
// different data loaders without blocking until the thunk is called.
func (l *Loader[KeyT, ValueT]) LoadAllThunk(ctx context.Context, keys []KeyT) func() ([]ValueT, []error) {
func (l *Loader[KeyT, ValueT]) LoadAllThunk(ctx context.Context, keys []KeyT) func() ([]ValueT, error) {
thunks := make([]func() (ValueT, error), len(keys))
for i, key := range keys {
thunks[i] = l.LoadThunk(ctx, key)
}
return func() ([]ValueT, []error) {
return func() ([]ValueT, error) {
values := make([]ValueT, len(keys))
errors := make([]error, len(keys))
allNil := true
for i, thunk := range thunks {
values[i], errors[i] = thunk()
if allNil == true && errors[i] != nil {
allNil = false
}
}
if allNil {
return values, nil
}
return values, errors
return values, ErrorSlice(errors)
}
}

Expand Down

0 comments on commit 27a2990

Please sign in to comment.