Skip to content

Commit

Permalink
CNS-213: add support for filters in pairing
Browse files Browse the repository at this point in the history
  • Loading branch information
oren-lava committed Jun 7, 2023
1 parent 54c066b commit 9680bf3
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 88 deletions.
52 changes: 50 additions & 2 deletions x/pairing/keeper/filters/filter.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,68 @@
package filters

import (
"fmt"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/lavanet/lava/utils"
epochstoragetypes "github.com/lavanet/lava/x/epochstorage/types"
projectstypes "github.com/lavanet/lava/x/projects/types"
)

type Filter interface {
Filter(ctx sdk.Context, stakeEntry []epochstoragetypes.StakeEntry) []bool
Filter(ctx sdk.Context, providers []epochstoragetypes.StakeEntry) []bool
InitFilter(strictestPolicy projectstypes.Policy) bool // return if filter is usable (by the policy)
}

func GetAllFilters() []Filter {
var selectedProvidersFilter SelectedProvidersFilter
var frozenProvidersFilter FrozenProvidersFilter
var geolocationFilter GeolocationFilter

filters := []Filter{&selectedProvidersFilter, &frozenProvidersFilter}
filters := []Filter{&selectedProvidersFilter, &frozenProvidersFilter, &geolocationFilter}
return filters
}

func initFilters(filters []Filter, strictestPolicy projectstypes.Policy) []Filter {
activeFilters := []Filter{}

for _, filter := range filters {
active := filter.InitFilter(strictestPolicy)
if active {
activeFilters = append(activeFilters, filter)
}
}

return activeFilters
}

func FilterProviders(ctx sdk.Context, filters []Filter, providers []epochstoragetypes.StakeEntry, strictestPolicy projectstypes.Policy) []epochstoragetypes.StakeEntry {
filters = initFilters(filters, strictestPolicy)

filtersResult := make([]bool, len(providers))
for i := range filtersResult {
filtersResult[i] = true
}

for _, filter := range filters {
res := filter.Filter(ctx, providers)
if len(res) != len(providers) {
utils.LavaFormatError("filter result length is not equal to providers list length", fmt.Errorf("filter failed"),
utils.Attribute{},
)
}

for i := range res {
filtersResult[i] = filtersResult[i] && res[i]
}
}

filteredProviders := []epochstoragetypes.StakeEntry{}
for i := range filtersResult {
if filtersResult[i] {
filteredProviders = append(filteredProviders, providers[i])
}
}

return filteredProviders
}
12 changes: 9 additions & 3 deletions x/pairing/keeper/grpc_query_static_providers_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (

sdk "github.com/cosmos/cosmos-sdk/types"
epochstoragetypes "github.com/lavanet/lava/x/epochstorage/types"
pairingfilters "github.com/lavanet/lava/x/pairing/keeper/filters"
"github.com/lavanet/lava/x/pairing/types"
projectstypes "github.com/lavanet/lava/x/projects/types"
spectypes "github.com/lavanet/lava/x/spec/types"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -40,12 +42,16 @@ func (k Keeper) StaticProvidersList(goCtx context.Context, req *types.QueryStati
}

finalProviders := []epochstoragetypes.StakeEntry{}
geolocation := uint64(1)
var geolocationFilter pairingfilters.GeolocationFilter
policy := projectstypes.Policy{
GeolocationProfile: uint64(1),
}
_ = geolocationFilter.InitFilter(policy)
for i := uint64(0); i < k.specKeeper.GeolocationCount(ctx); i++ {
validProviders := k.getUnfrozenGeolocationProviders(ctx, stakes, geolocation)
validProviders := pairingfilters.FilterProviders(ctx, []pairingfilters.Filter{&geolocationFilter}, stakes, policy)
validProviders = k.returnSubsetOfProvidersByHighestStake(ctx, validProviders, servicersToPairCount)
finalProviders = append(finalProviders, validProviders...)
geolocation <<= 1
policy.GeolocationProfile <<= 1
}

return &types.QueryStaticProvidersListResponse{Providers: finalProviders}, nil
Expand Down
111 changes: 28 additions & 83 deletions x/pairing/keeper/pairing.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
commontypes "github.com/lavanet/lava/common/types"
"github.com/lavanet/lava/utils"
epochstoragetypes "github.com/lavanet/lava/x/epochstorage/types"
pairingfilters "github.com/lavanet/lava/x/pairing/keeper/filters"
projectstypes "github.com/lavanet/lava/x/projects/types"
spectypes "github.com/lavanet/lava/x/spec/types"
tendermintcrypto "github.com/tendermint/tendermint/crypto"
Expand Down Expand Up @@ -105,12 +106,11 @@ func (k Keeper) GetPairingForClient(ctx sdk.Context, chainID string, clientAddre

// function used to get a new pairing from provider and client
// first argument has all metadata, second argument is only the addresses
func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, clientAddress sdk.AccAddress, block uint64) (providers []epochstoragetypes.StakeEntry, allowedCU uint64, legacyStake bool, errorRet error) {
var geolocation uint64
var providersToPair uint64
func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, clientAddress sdk.AccAddress, block uint64) ([]epochstoragetypes.StakeEntry, uint64, bool, error) {
var allowedCU uint64
var projectToPair string
var selectedProvidersMode projectstypes.PolicySelectedProvidersModeEnum
var selectedProvidersList []string
var legacyStake bool
var strictestPolicy projectstypes.Policy

epoch, err := k.VerifyPairingData(ctx, chainID, clientAddress, block)
if err != nil {
Expand All @@ -120,7 +120,7 @@ func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, clientAddre
project, err := k.GetProjectData(ctx, clientAddress, chainID, block)
if err == nil {
legacyStake = false
geolocation, providersToPair, projectToPair, allowedCU, selectedProvidersMode, selectedProvidersList, err = k.getProjectStrictestPolicy(ctx, project, chainID)
strictestPolicy, projectToPair, allowedCU, err = k.getProjectStrictestPolicy(ctx, project, chainID)
if err != nil {
return nil, 0, false, fmt.Errorf("invalid user for pairing: %s", err.Error())
}
Expand All @@ -131,14 +131,14 @@ func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, clientAddre
// user is not valid for pairing
return nil, 0, false, fmt.Errorf("invalid user for pairing: 1) %s 2) %s", err.Error(), err2.Error())
}
geolocation = clientStakeEntry.Geolocation
strictestPolicy.GeolocationProfile = clientStakeEntry.Geolocation

servicersToPairCount, err := k.ServicersToPairCount(ctx, block)
if err != nil {
return nil, 0, false, err
}

providersToPair = servicersToPairCount
strictestPolicy.MaxProvidersToPair = servicersToPairCount
projectToPair = clientAddress.String()

allowedCU, err = k.ClientMaxCUProviderForBlock(ctx, block, clientStakeEntry)
Expand All @@ -154,29 +154,20 @@ func (k Keeper) getPairingForClient(ctx sdk.Context, chainID string, clientAddre
return nil, 0, false, fmt.Errorf("did not find providers for pairing: epoch:%d, chainID: %s", block, chainID)
}

switch selectedProvidersMode {
case projectstypes.Policy_EXCLUSIVE, projectstypes.Policy_MIXED:
possibleProviders, err = k.getStakeEntriesOfSelectedProviders(ctx, possibleProviders, selectedProvidersList)
if err != nil {
return nil, 0, false, err
}

if uint64(len(possibleProviders)) <= providersToPair {
return possibleProviders, allowedCU, legacyStake, err
}
}
filters := pairingfilters.GetAllFilters()

// TODO: for mixed mode - change the providersToPair variable
possibleProviders = pairingfilters.FilterProviders(ctx, filters, possibleProviders, strictestPolicy)

providers, err = k.calculatePairingForClient(ctx, possibleProviders, projectToPair, block, chainID, geolocation, epochHash, providersToPair)
providers, err := k.calculatePairingForClient(ctx, possibleProviders, projectToPair, block,
chainID, epochHash, strictestPolicy.MaxProvidersToPair)

return providers, allowedCU, legacyStake, err
}

func (k Keeper) getProjectStrictestPolicy(ctx sdk.Context, project projectstypes.Project, chainID string) (uint64, uint64, string, uint64, projectstypes.PolicySelectedProvidersModeEnum, []string, error) {
func (k Keeper) getProjectStrictestPolicy(ctx sdk.Context, project projectstypes.Project, chainID string) (projectstypes.Policy, string, uint64, error) {
plan, err := k.subscriptionKeeper.GetPlanFromSubscription(ctx, project.GetSubscription())
if err != nil {
return 0, 0, "", 0, 0, []string{}, err
return projectstypes.Policy{}, "", 0, err
}

planPolicy := plan.GetPlanPolicy()
Expand All @@ -189,27 +180,34 @@ func (k Keeper) getProjectStrictestPolicy(ctx sdk.Context, project projectstypes
}

if !projectstypes.CheckChainIdExistsInPolicies(chainID, policies) {
return 0, 0, "", 0, 0, []string{}, fmt.Errorf("chain ID not found in any of the policies")
return projectstypes.Policy{}, "", 0, fmt.Errorf("chain ID not found in any of the policies")
}

geolocation := k.CalculateEffectiveGeolocationFromPolicies(policies)

providersToPair := k.CalculateEffectiveProvidersToPairFromPolicies(policies)
if providersToPair == uint64(math.MaxUint64) {
return 0, 0, "", 0, 0, []string{}, fmt.Errorf("could not calculate providersToPair value: all policies are nil")
return projectstypes.Policy{}, "", 0, fmt.Errorf("could not calculate providersToPair value: all policies are nil")
}

sub, found := k.subscriptionKeeper.GetSubscription(ctx, project.GetSubscription())
if !found {
return 0, 0, "", 0, 0, []string{}, fmt.Errorf("could not find subscription with address %s", project.GetSubscription())
return projectstypes.Policy{}, "", 0, fmt.Errorf("could not find subscription with address %s", project.GetSubscription())
}
allowedCU := k.CalculateEffectiveAllowedCuPerEpochFromPolicies(policies, project.GetUsedCu(), sub.GetMonthCuLeft())

projectToPair := project.Index

selectedProvidersMode, selectedProvidersList := k.CalculateEffectiveSelectedProviders(policies)

return geolocation, providersToPair, projectToPair, allowedCU, selectedProvidersMode, selectedProvidersList, nil
strictestPolicy := projectstypes.Policy{
GeolocationProfile: geolocation,
MaxProvidersToPair: providersToPair,
SelectedProvidersMode: selectedProvidersMode,
SelectedProviders: selectedProvidersList,
}

return strictestPolicy, projectToPair, allowedCU, nil
}

func (k Keeper) CalculateEffectiveSelectedProviders(policies []*projectstypes.Policy) (projectstypes.PolicySelectedProvidersModeEnum, []string) {
Expand All @@ -230,33 +228,6 @@ func (k Keeper) CalculateEffectiveSelectedProviders(policies []*projectstypes.Po
return effectiveMode, effectiveSelectedProviders
}

func (k Keeper) getStakeEntriesOfSelectedProviders(ctx sdk.Context, possibleProviders []epochstoragetypes.StakeEntry, selectedProviders []string) ([]epochstoragetypes.StakeEntry, error) {
if len(selectedProviders) == 0 {
return nil, utils.LavaFormatWarning("selected providers intersection set is empty", fmt.Errorf("no providers to pair"))
}

selectedProvidersMap := map[string]string{}
for _, selectedProviderAddr := range selectedProviders {
selectedProvidersMap[selectedProviderAddr] = ""
}

providers := []epochstoragetypes.StakeEntry{}
for _, providerStakeEntry := range possibleProviders {
_, found := selectedProvidersMap[providerStakeEntry.Address]
if found && !isProviderFrozen(ctx, providerStakeEntry) {
// selected providers are not affected by geolocation -> make them include all (max uint64)
providerStakeEntry.Geolocation = math.MaxUint64
providers = append(providers, providerStakeEntry)
}
}

if len(providers) == 0 {
return nil, utils.LavaFormatWarning("none of the selected providers intersection set is staked", fmt.Errorf("no providers to pair"))
}

return providers, nil
}

func (k Keeper) CalculateEffectiveGeolocationFromPolicies(policies []*projectstypes.Policy) uint64 {
geolocation := uint64(math.MaxUint64)

Expand Down Expand Up @@ -326,7 +297,7 @@ func (k Keeper) ValidatePairingForClient(ctx sdk.Context, chainID string, client
return false, allowedCU, 0, legacyStake, nil
}

func (k Keeper) calculatePairingForClient(ctx sdk.Context, providers []epochstoragetypes.StakeEntry, developerAddress string, epochStartBlock uint64, chainID string, geolocation uint64, epochHash []byte, providersToPair uint64) (validProviders []epochstoragetypes.StakeEntry, err error) {
func (k Keeper) calculatePairingForClient(ctx sdk.Context, providers []epochstoragetypes.StakeEntry, developerAddress string, epochStartBlock uint64, chainID string, epochHash []byte, providersToPair uint64) (validProviders []epochstoragetypes.StakeEntry, err error) {
if epochStartBlock > uint64(ctx.BlockHeight()) {
k.Logger(ctx).Error("\ninvalid session start\n")
panic(fmt.Sprintf("invalid session start saved in keeper %d, current block was %d", epochStartBlock, uint64(ctx.BlockHeight())))
Expand All @@ -337,43 +308,17 @@ func (k Keeper) calculatePairingForClient(ctx sdk.Context, providers []epochstor
return nil, fmt.Errorf("spec not found or not enabled")
}

validProviders = k.getUnfrozenGeolocationProviders(ctx, providers, geolocation)

if spec.ProvidersTypes == spectypes.Spec_dynamic {
// calculates a hash and randomly chooses the providers

validProviders = k.returnSubsetOfProvidersByStake(ctx, developerAddress, validProviders, providersToPair, epochStartBlock, chainID, epochHash)
validProviders = k.returnSubsetOfProvidersByStake(ctx, developerAddress, providers, providersToPair, epochStartBlock, chainID, epochHash)
} else {
validProviders = k.returnSubsetOfProvidersByHighestStake(ctx, validProviders, providersToPair)
validProviders = k.returnSubsetOfProvidersByHighestStake(ctx, providers, providersToPair)
}

return validProviders, nil
}

func (k Keeper) getUnfrozenGeolocationProviders(ctx sdk.Context, providers []epochstoragetypes.StakeEntry, geolocation uint64) []epochstoragetypes.StakeEntry {
validProviders := []epochstoragetypes.StakeEntry{}
// create a list of valid providers (stakeAppliedBlock reached)
for _, stakeEntry := range providers {
if isProviderFrozen(ctx, stakeEntry) {
// provider stakeAppliedBlock wasn't reached yet
continue
}

geolocationSupported := stakeEntry.Geolocation & geolocation
if geolocationSupported == 0 {
// no match in geolocation bitmap
continue
}

validProviders = append(validProviders, stakeEntry)
}
return validProviders
}

func isProviderFrozen(ctx sdk.Context, stakeEntry epochstoragetypes.StakeEntry) bool {
return stakeEntry.StakeAppliedBlock > uint64(ctx.BlockHeight())
}

// this function randomly chooses count providers by weight
func (k Keeper) returnSubsetOfProvidersByStake(ctx sdk.Context, clientAddress string, providersMaps []epochstoragetypes.StakeEntry, count uint64, block uint64, chainID string, epochHash []byte) (returnedProviders []epochstoragetypes.StakeEntry) {
stakeSum := sdk.NewCoin(epochstoragetypes.TokenDenom, sdk.NewInt(0))
Expand Down

0 comments on commit 9680bf3

Please sign in to comment.