diff --git a/service/share/share.go b/service/share/share.go index 62bb4a1ceb..515a3fb9e4 100644 --- a/service/share/share.go +++ b/service/share/share.go @@ -5,6 +5,8 @@ import ( "fmt" "math/rand" + "golang.org/x/sync/errgroup" + "github.com/ipfs/go-cid" format "github.com/ipfs/go-ipld-format" logging "github.com/ipfs/go-log/v2" @@ -161,34 +163,25 @@ func (s *service) GetSharesByNamespace(ctx context.Context, root *Root, nID name return nil, ipld.ErrNotFoundInRange } - type res struct { - nodes []format.Node - err error + errGroup, ctx := errgroup.WithContext(ctx) + nodes := make([][]format.Node, len(rowRootCIDs)) + for i, rootCID := range rowRootCIDs { + // shadow loop variables, to ensure correct values are captured + i, rootCID := i, rootCID + errGroup.Go(func() (err error) { + nodes[i], err = ipld.GetLeavesByNamespace(ctx, s.dag, rootCID, nID) + return + }) } - resultCh := make(chan *res) - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - for _, rootCID := range rowRootCIDs { - go func(rootCID cid.Cid) { - nodes, err := ipld.GetLeavesByNamespace(ctx, s.dag, rootCID, nID) - resultCh <- &res{nodes: nodes, err: err} - }(rootCID) + if err := errGroup.Wait(); err != nil { + return nil, err } namespacedShares := make([]Share, 0) for i := 0; i < len(rowRootCIDs); i++ { - select { - case result := <-resultCh: - if result.err != nil { - return nil, result.err - } - for _, node := range result.nodes { - namespacedShares = append(namespacedShares, node.RawData()[1:]) - } - case <-ctx.Done(): - return nil, ctx.Err() + for _, node := range nodes[i] { + namespacedShares = append(namespacedShares, node.RawData()[1:]) } } diff --git a/service/share/share_test.go b/service/share/share_test.go index a618221d65..5f8d831c20 100644 --- a/service/share/share_test.go +++ b/service/share/share_test.go @@ -35,22 +35,27 @@ func TestGetShare(t *testing.T) { func TestService_GetSharesByNamespace(t *testing.T) { var tests = []struct { - amountShares int + squareSize int expectedShareCount int }{ - {amountShares: 4, expectedShareCount: 1}, - {amountShares: 16, expectedShareCount: 2}, - {amountShares: 128, expectedShareCount: 1}, + {squareSize: 4, expectedShareCount: 1}, + {squareSize: 16, expectedShareCount: 2}, + {squareSize: 128, expectedShareCount: 1}, } for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { - serv, root := RandLightServiceWithSquare(t, tt.amountShares) - randNID := root.RowsRoots[(len(root.RowsRoots)-1)/2][:8] + serv, dag := RandLightService() + n := tt.squareSize * tt.squareSize + randShares := RandShares(t, n) + idx1 := (n - 1) / 2 + idx2 := n / 2 if tt.expectedShareCount > 1 { // make it so that two rows have the same namespace ID - root.RowsRoots[(len(root.RowsRoots) / 2)] = root.RowsRoots[(len(root.RowsRoots)-1)/2] + copy(randShares[idx2][:8], randShares[idx1][:8]) } + root := FillDag(t, tt.squareSize, dag, randShares) + randNID := []byte(randShares[idx1][:8]) shares, err := serv.GetSharesByNamespace(context.Background(), root, randNID) require.NoError(t, err) @@ -58,6 +63,11 @@ func TestService_GetSharesByNamespace(t *testing.T) { for _, value := range shares { assert.Equal(t, randNID, []byte(value.NamespaceID())) } + if tt.expectedShareCount > 1 { + // idx1 is always smaller than idx2 + assert.Equal(t, []byte(randShares[idx1]), shares[0].Data()) + assert.Equal(t, []byte(randShares[idx2]), shares[1].Data()) + } }) } } diff --git a/service/share/testing.go b/service/share/testing.go index a60bbce8bf..58996f29e5 100644 --- a/service/share/testing.go +++ b/service/share/testing.go @@ -35,6 +35,13 @@ func RandLightServiceWithSquare(t *testing.T, n int) (Service, *Root) { return NewService(dag, NewLightAvailability(dag)), RandFillDAG(t, n, dag) } +// RandLightService provides an unfilled share.Service with corresponding +// format.DAGService than can be filled by the test. +func RandLightService() (Service, format.DAGService) { + dag := mdutils.Mock() + return NewService(dag, NewLightAvailability(dag)), dag +} + // RandFullServiceWithSquare provides a share.Service filled with 'n' NMT // trees of 'n' random shares, essentially storing a whole square. func RandFullServiceWithSquare(t *testing.T, n int) (Service, *Root) { @@ -44,6 +51,10 @@ func RandFullServiceWithSquare(t *testing.T, n int) (Service, *Root) { func RandFillDAG(t *testing.T, n int, dag format.DAGService) *Root { shares := RandShares(t, n*n) + return FillDag(t, n, dag, shares) +} + +func FillDag(t *testing.T, n int, dag format.DAGService, shares []Share) *Root { sharesSlices := make([][]byte, n*n) for i, share := range shares { sharesSlices[i] = share