diff --git a/.golangci.yml b/.golangci.yml index 8186e95a44..28288d72f8 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -65,7 +65,7 @@ issues: linters-settings: dogsled: - max-blank-identifiers: 3 + max-blank-identifiers: 5 maligned: # print struct with more effective memory layout or not, false by default suggest-new: true diff --git a/x/conflict/keeper/msg_server_detection_test.go b/x/conflict/keeper/msg_server_detection_test.go index 67396fb5da..959283e587 100644 --- a/x/conflict/keeper/msg_server_detection_test.go +++ b/x/conflict/keeper/msg_server_detection_test.go @@ -31,9 +31,10 @@ func setupForConflictTests(t *testing.T, numOfProviders int) testStruct { // init keepers state var balance int64 = 100000 + // setup consumer ts.consumer = common.CreateNewAccount(ts.ctx, *ts.keepers, balance) - // setup consumer + // setup providers for i := 0; i < numOfProviders; i++ { ts.Providers = append(ts.Providers, common.CreateNewAccount(ts.ctx, *ts.keepers, balance)) } diff --git a/x/pairing/keeper/pairing.go b/x/pairing/keeper/pairing.go index bb4dfff34d..8de8ed5e57 100644 --- a/x/pairing/keeper/pairing.go +++ b/x/pairing/keeper/pairing.go @@ -161,7 +161,14 @@ func (k Keeper) getProjectStrictestPolicy(ctx sdk.Context, project projectstypes } planPolicy := plan.GetPlanPolicy() - policies := []*projectstypes.Policy{project.AdminPolicy, project.SubscriptionPolicy, &planPolicy} + policies := []*projectstypes.Policy{&planPolicy} + if project.SubscriptionPolicy != nil { + policies = append(policies, project.SubscriptionPolicy) + } + if project.AdminPolicy != nil { + policies = append(policies, project.AdminPolicy) + } + if !projectstypes.CheckChainIdExistsInPolicies(chainID, policies) { return 0, 0, "", 0, fmt.Errorf("chain ID not found in any of the policies") } @@ -169,6 +176,9 @@ func (k Keeper) getProjectStrictestPolicy(ctx sdk.Context, project projectstypes geolocation := k.CalculateEffectiveGeolocationFromPolicies(policies) providersToPair := k.CalculateEffectiveProvidersToPairFromPolicies(policies) + if providersToPair == uint64(math.MaxUint64) { + return 0, 0, "", 0, fmt.Errorf("could not calculate providersToPair value: all policies are nil") + } sub, found := k.subscriptionKeeper.GetSubscription(ctx, project.GetSubscription()) if !found { @@ -194,15 +204,16 @@ func (k Keeper) CalculateEffectiveGeolocationFromPolicies(policies []*projectsty } func (k Keeper) CalculateEffectiveProvidersToPairFromPolicies(policies []*projectstypes.Policy) uint64 { - var providersToPairValues []uint64 + providersToPair := uint64(math.MaxUint64) for _, policy := range policies { - if policy != nil { - providersToPairValues = append(providersToPairValues, policy.GetMaxProvidersToPair()) + val := policy.GetMaxProvidersToPair() + if policy != nil && val < providersToPair { + providersToPair = val } } - return commontypes.FindMin(providersToPairValues) + return providersToPair } func (k Keeper) CalculateEffectiveAllowedCuPerEpochFromPolicies(policies []*projectstypes.Policy, cuUsedInProject uint64, cuLeftInSubscription uint64) uint64 { diff --git a/x/pairing/keeper/pairing_subscription_test.go b/x/pairing/keeper/pairing_subscription_test.go index f4e902a82c..a1fa68e9cd 100644 --- a/x/pairing/keeper/pairing_subscription_test.go +++ b/x/pairing/keeper/pairing_subscription_test.go @@ -121,30 +121,59 @@ func TestRelayPaymentSubscription(t *testing.T) { func TestRelayPaymentSubscriptionCU(t *testing.T) { ts := setupForPaymentTest(t) var balance int64 = 10000 - consumer := common.CreateNewAccount(ts.ctx, *ts.keepers, balance) - _, err := ts.servers.SubscriptionServer.Buy(ts.ctx, &subtypes.MsgBuy{Creator: consumer.Addr.String(), Consumer: consumer.Addr.String(), Index: ts.plan.Index, Duration: 1}) - require.Nil(t, err) + consumerA := common.CreateNewAccount(ts.ctx, *ts.keepers, balance) + consumerB := common.CreateNewAccount(ts.ctx, *ts.keepers, balance) - ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + consumers := []common.Account{consumerA, consumerB} - pairingReq := types.QueryGetPairingRequest{ChainID: ts.spec.Index, Client: consumer.Addr.String()} - pairing, err := ts.keepers.Pairing.GetPairing(ts.ctx, &pairingReq) + _, err := ts.servers.SubscriptionServer.Buy(ts.ctx, &subtypes.MsgBuy{Creator: consumerA.Addr.String(), Consumer: consumerA.Addr.String(), Index: ts.plan.Index, Duration: 1}) require.Nil(t, err) - verifyPairingQuery := &types.QueryVerifyPairingRequest{ChainID: ts.spec.Index, Client: consumer.Addr.String(), Provider: pairing.Providers[0].Address, Block: uint64(sdk.UnwrapSDKContext(ts.ctx).BlockHeight())} - vefiry, err := ts.keepers.Pairing.VerifyPairing(ts.ctx, verifyPairingQuery) + consumerBProjectData := projectstypes.ProjectData{ + Name: "consumerBProject", + Description: "", + Enabled: true, + ProjectKeys: []projectstypes.ProjectKey{{ + Key: consumerB.Addr.String(), + Types: []projectstypes.ProjectKey_KEY_TYPE{ + projectstypes.ProjectKey_ADMIN, + projectstypes.ProjectKey_DEVELOPER, + }, + Vrfpk: "", + }}, + Policy: &ts.plan.PlanPolicy, + } + err = ts.keepers.Subscription.AddProjectToSubscription(sdk.UnwrapSDKContext(ts.ctx), consumerA.Addr.String(), consumerBProjectData) require.Nil(t, err) - require.True(t, vefiry.Valid) - _, _, err = ts.keepers.Projects.GetProjectForDeveloper(sdk.UnwrapSDKContext(ts.ctx), consumer.Addr.String(), uint64(sdk.UnwrapSDKContext(ts.ctx).BlockHeight())) + // verify both projects exist + projA, _, err := ts.keepers.Projects.GetProjectForDeveloper(sdk.UnwrapSDKContext(ts.ctx), consumerA.Addr.String(), uint64(sdk.UnwrapSDKContext(ts.ctx).BlockHeight())) + require.Nil(t, err) + projB, _, err := ts.keepers.Projects.GetProjectForDeveloper(sdk.UnwrapSDKContext(ts.ctx), consumerB.Addr.String(), uint64(sdk.UnwrapSDKContext(ts.ctx).BlockHeight())) require.Nil(t, err) + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + + // verify that both consumers are paired + for _, consumer := range consumers { + pairingReq := types.QueryGetPairingRequest{ChainID: ts.spec.Index, Client: consumer.Addr.String()} + pairing, err := ts.keepers.Pairing.GetPairing(ts.ctx, &pairingReq) + require.Nil(t, err) + + verifyPairingQuery := &types.QueryVerifyPairingRequest{ChainID: ts.spec.Index, Client: consumer.Addr.String(), Provider: pairing.Providers[0].Address, Block: uint64(sdk.UnwrapSDKContext(ts.ctx).BlockHeight())} + verify, err := ts.keepers.Pairing.VerifyPairing(ts.ctx, verifyPairingQuery) + require.Nil(t, err) + require.True(t, verify.Valid) + } + + // both projects have adminPolicy, subscriptionPolicy = nil -> they go by the plan policy + // waste all the subscription's CU on project A i := 0 for ; uint64(i) < ts.plan.PlanPolicy.GetTotalCuLimit()/ts.plan.PlanPolicy.GetEpochCuLimit(); i++ { relayRequest := common.BuildRelayRequest(ts.ctx, ts.providers[0].Addr.String(), []byte(ts.spec.Apis[0].Name), ts.plan.PlanPolicy.GetEpochCuLimit(), ts.spec.Name, nil) relayRequest.SessionId = uint64(i) - relayRequest.Sig, err = sigs.SignRelay(consumer.SK, *relayRequest) + relayRequest.Sig, err = sigs.SignRelay(consumerA.SK, *relayRequest) require.Nil(t, err) _, err = ts.servers.PairingServer.RelayPayment(ts.ctx, &types.MsgRelayPayment{Creator: ts.providers[0].Addr.String(), Relays: []*types.RelaySession{relayRequest}}) require.Nil(t, err) @@ -152,11 +181,463 @@ func TestRelayPaymentSubscriptionCU(t *testing.T) { ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) } - // last iteration should finish the plan quota + // last iteration should finish the plan and subscription quota relayRequest := common.BuildRelayRequest(ts.ctx, ts.providers[0].Addr.String(), []byte(ts.spec.Apis[0].Name), ts.plan.PlanPolicy.GetEpochCuLimit(), ts.spec.Name, nil) relayRequest.SessionId = uint64(i + 1) - relayRequest.Sig, err = sigs.SignRelay(consumer.SK, *relayRequest) + relayRequest.Sig, err = sigs.SignRelay(consumerA.SK, *relayRequest) + require.Nil(t, err) + _, err = ts.servers.PairingServer.RelayPayment(ts.ctx, &types.MsgRelayPayment{Creator: ts.providers[0].Addr.String(), Relays: []*types.RelaySession{relayRequest}}) + require.NotNil(t, err) + + // verify that project A wasted all of the subscription's CU + sub, found := ts.keepers.Subscription.GetSubscription(sdk.UnwrapSDKContext(ts.ctx), projA.Subscription) + require.True(t, found) + require.Equal(t, uint64(0), sub.MonthCuLeft) + projA, _, err = ts.keepers.Projects.GetProjectForDeveloper(sdk.UnwrapSDKContext(ts.ctx), consumerA.Addr.String(), uint64(sdk.UnwrapSDKContext(ts.ctx).BlockHeight())) + require.Nil(t, err) + require.Equal(t, sub.MonthCuTotal, projA.UsedCu) + require.Equal(t, uint64(0), projB.UsedCu) + + // try to use CU on projB. Should fail because A wasted it all + relayRequest.SessionId += 1 + relayRequest.Sig, err = sigs.SignRelay(consumerB.SK, *relayRequest) require.Nil(t, err) _, err = ts.servers.PairingServer.RelayPayment(ts.ctx, &types.MsgRelayPayment{Creator: ts.providers[0].Addr.String(), Relays: []*types.RelaySession{relayRequest}}) require.NotNil(t, err) } + +func TestStrictestPolicyGeolocation(t *testing.T) { + ts := setupForPaymentTest(t) + + // make the plan policy's geolocation 7(=111) + ts.plan.PlanPolicy.GeolocationProfile = 7 + err := ts.keepers.Plans.AddPlan(sdk.UnwrapSDKContext(ts.ctx), ts.plan) + require.Nil(t, err) + + err = ts.keepers.Subscription.CreateSubscription(sdk.UnwrapSDKContext(ts.ctx), + ts.clients[0].Addr.String(), ts.clients[0].Addr.String(), ts.plan.Index, 10, "") + require.Nil(t, err) + + proj, _, err := ts.keepers.Projects.GetProjectForDeveloper(sdk.UnwrapSDKContext(ts.ctx), + ts.clients[0].Addr.String(), uint64(sdk.UnwrapSDKContext(ts.ctx).BlockHeight())) + require.Nil(t, err) + + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + + geolocationTestTemplates := []struct { + name string + geolocationAdminPolicy uint64 + geolocationSubPolicy uint64 + success bool + }{ + {"effective geo = 1", uint64(1), uint64(1), true}, + {"effective geo = 3 (includes geo=1)", uint64(3), uint64(3), true}, + {"effective geo = 2", uint64(3), uint64(2), false}, + {"effective geo = 0 (planPolicy & subPolicy = 1)", uint64(2), uint64(1), false}, + {"effective geo = 0 (planPolicy & adminPolicy = 1)", uint64(1), uint64(2), false}, + } + + for _, tt := range geolocationTestTemplates { + t.Run(tt.name, func(t *testing.T) { + adminPolicy := &projectstypes.Policy{ + GeolocationProfile: tt.geolocationAdminPolicy, + } + subscriptionPolicy := &projectstypes.Policy{ + GeolocationProfile: tt.geolocationSubPolicy, + } + + _, err = ts.servers.ProjectServer.SetAdminPolicy(ts.ctx, &projectstypes.MsgSetAdminPolicy{ + Creator: ts.clients[0].Addr.String(), + Project: proj.Index, + Policy: *adminPolicy, + }) + require.Nil(t, err) + + // apply the policy setting + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + + _, err = ts.servers.ProjectServer.SetSubscriptionPolicy(ts.ctx, &projectstypes.MsgSetSubscriptionPolicy{ + Creator: ts.clients[0].Addr.String(), + Projects: []string{proj.Index}, + Policy: *subscriptionPolicy, + }) + require.Nil(t, err) + + // apply the policy setting + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + + // the only provider is set with geolocation=1. So only geolocation that ANDs + // with 1 and output non-zero result, will output a provider for pairing + getPairingResponse, err := ts.keepers.Pairing.GetPairing(ts.ctx, &types.QueryGetPairingRequest{ + ChainID: ts.spec.Index, + Client: ts.clients[0].Addr.String(), + }) + require.Nil(t, err) + if tt.success { + require.NotEqual(t, 0, len(getPairingResponse.Providers)) + } else { + require.Equal(t, 0, len(getPairingResponse.Providers)) + } + }) + } +} + +func TestStrictestPolicyProvidersToPair(t *testing.T) { + ts := setupForPaymentTest(t) + + // add 5 more providers so we can have enough providers for testing + ts.addProvider(5) + + _, err := ts.servers.SubscriptionServer.Buy(ts.ctx, &subtypes.MsgBuy{ + Creator: ts.clients[0].Addr.String(), + Consumer: ts.clients[0].Addr.String(), + Index: ts.plan.Index, + Duration: 10, + Vrfpk: "", + }) + require.Nil(t, err) + + developerQueryResponse, err := ts.keepers.Projects.Developer(ts.ctx, &projectstypes.QueryDeveloperRequest{ + Developer: ts.clients[0].Addr.String(), + }) + require.Nil(t, err) + proj := developerQueryResponse.Project + + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + + providersToPairTestTemplates := []struct { + name string + providersToPairAdminPolicy uint64 + providersToPairSubPolicy uint64 + effectiveProvidersToPair int + }{ + {"effective providersToPair = 2", uint64(4), uint64(2), 2}, + {"sub policy providersToPair = 1", uint64(1), uint64(3), 3}, + {"admin policy providersToPair = 1", uint64(3), uint64(1), 3}, + } + + for _, tt := range providersToPairTestTemplates { + t.Run(tt.name, func(t *testing.T) { + adminPolicy := &projectstypes.Policy{ + GeolocationProfile: 1, + MaxProvidersToPair: tt.providersToPairAdminPolicy, + } + subscriptionPolicy := &projectstypes.Policy{ + GeolocationProfile: 1, + MaxProvidersToPair: tt.providersToPairSubPolicy, + } + + _, err = ts.servers.ProjectServer.SetAdminPolicy(ts.ctx, &projectstypes.MsgSetAdminPolicy{ + Creator: ts.clients[0].Addr.String(), + Project: proj.Index, + Policy: *adminPolicy, + }) + require.Nil(t, err) + + // apply the policy setting + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + + _, err = ts.servers.ProjectServer.SetSubscriptionPolicy(ts.ctx, &projectstypes.MsgSetSubscriptionPolicy{ + Creator: ts.clients[0].Addr.String(), + Projects: []string{proj.Index}, + Policy: *subscriptionPolicy, + }) + require.Nil(t, err) + + // apply the policy setting + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + + getPairingResponse, err := ts.keepers.Pairing.GetPairing(ts.ctx, &types.QueryGetPairingRequest{ + ChainID: ts.spec.Index, + Client: ts.clients[0].Addr.String(), + }) + require.Nil(t, err) + require.Equal(t, tt.effectiveProvidersToPair, len(getPairingResponse.Providers)) + }) + } +} + +func TestStrictestPolicyCuPerEpoch(t *testing.T) { + ts := setupForPaymentTest(t) + + _, err := ts.servers.SubscriptionServer.Buy(ts.ctx, &subtypes.MsgBuy{ + Creator: ts.clients[0].Addr.String(), + Consumer: ts.clients[0].Addr.String(), + Index: ts.plan.Index, + Duration: 10, + Vrfpk: "", + }) + require.Nil(t, err) + + developerQueryResponse, err := ts.keepers.Projects.Developer(ts.ctx, &projectstypes.QueryDeveloperRequest{ + Developer: ts.clients[0].Addr.String(), + }) + require.Nil(t, err) + proj := developerQueryResponse.Project + + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + + providersToPairTestTemplates := []struct { + name string + cuPerEpochAdminPolicy uint64 + cuPerEpochSubPolicy uint64 + useMostOfProjectCu bool + wasteSubscriptionCu bool + effectiveCuPerEpochLimit uint64 + }{ + {"admin policy min CU", uint64(90), uint64(110), false, false, uint64(90)}, + {"sub policy min CU", uint64(110), uint64(90), false, false, uint64(90)}, + {"use most of the project's CU", uint64(100), uint64(100), true, false, uint64(10)}, + {"waste subscription CU", uint64(100), uint64(100), false, true, uint64(0)}, + } + + for _, tt := range providersToPairTestTemplates { + t.Run(tt.name, func(t *testing.T) { + consumer := ts.clients[0] + + // add a new project to the subscription just to waste the subcsription's cu + if tt.wasteSubscriptionCu { + err = ts.addClient(1) + require.Nil(t, err) + + consumerToWasteCu := ts.clients[1] + + // pair new client with provider + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + + projectData := projectstypes.ProjectData{ + Name: "lowCuProject", + Description: "project with low CU limit (per epoch)", + Enabled: true, + ProjectKeys: []projectstypes.ProjectKey{{ + Key: consumerToWasteCu.Addr.String(), + Types: []projectstypes.ProjectKey_KEY_TYPE{ + projectstypes.ProjectKey_DEVELOPER, + projectstypes.ProjectKey_ADMIN, + }, + Vrfpk: "", + }}, + Policy: &ts.plan.PlanPolicy, + } + _, err = ts.servers.SubscriptionServer.AddProject(ts.ctx, &subtypes.MsgAddProject{ + Creator: proj.Subscription, + ProjectData: projectData, + }) + require.Nil(t, err) + + sub, found := ts.keepers.Subscription.GetSubscription(sdk.UnwrapSDKContext(ts.ctx), proj.Subscription) + require.True(t, found) + + relayRequest := common.BuildRelayRequest(ts.ctx, ts.providers[0].Addr.String(), []byte(ts.spec.Apis[0].Name), sub.MonthCuLeft, ts.spec.Name, nil) + relayRequest.SessionId = uint64(100) + relayRequest.Sig, err = sigs.SignRelay(consumerToWasteCu.SK, *relayRequest) + require.Nil(t, err) + _, err = ts.servers.PairingServer.RelayPayment(ts.ctx, &types.MsgRelayPayment{Creator: ts.providers[0].Addr.String(), Relays: []*types.RelaySession{relayRequest}}) + require.Nil(t, err) + + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + } + + adminPolicy := &projectstypes.Policy{ + GeolocationProfile: 1, + EpochCuLimit: tt.cuPerEpochAdminPolicy, + TotalCuLimit: 1000, + } + subscriptionPolicy := &projectstypes.Policy{ + GeolocationProfile: 1, + EpochCuLimit: tt.cuPerEpochSubPolicy, + TotalCuLimit: 1000, + } + + _, err = ts.servers.ProjectServer.SetAdminPolicy(ts.ctx, &projectstypes.MsgSetAdminPolicy{ + Creator: consumer.Addr.String(), + Project: proj.Index, + Policy: *adminPolicy, + }) + require.Nil(t, err) + + // apply the policy setting + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + + _, err = ts.servers.ProjectServer.SetSubscriptionPolicy(ts.ctx, &projectstypes.MsgSetSubscriptionPolicy{ + Creator: ts.clients[0].Addr.String(), + Projects: []string{proj.Index}, + Policy: *subscriptionPolicy, + }) + require.Nil(t, err) + + // apply the policy setting + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + + // leave 10 CU in the project + if tt.useMostOfProjectCu { + for i := 0; uint64(i) < adminPolicy.TotalCuLimit/ts.plan.PlanPolicy.GetEpochCuLimit(); i++ { + cuSum := ts.plan.PlanPolicy.GetEpochCuLimit() + + developerQueryResponse, err := ts.keepers.Projects.Developer(ts.ctx, &projectstypes.QueryDeveloperRequest{ + Developer: ts.clients[0].Addr.String(), + }) + require.Nil(t, err) + proj := developerQueryResponse.Project + if proj.UsedCu >= 900 { + cuSum = 90 + } + + relayRequest := common.BuildRelayRequest(ts.ctx, ts.providers[0].Addr.String(), []byte(ts.spec.Apis[0].Name), cuSum, ts.spec.Name, nil) + relayRequest.SessionId = uint64(i) + relayRequest.Sig, err = sigs.SignRelay(consumer.SK, *relayRequest) + require.Nil(t, err) + _, err = ts.servers.PairingServer.RelayPayment(ts.ctx, &types.MsgRelayPayment{Creator: ts.providers[0].Addr.String(), Relays: []*types.RelaySession{relayRequest}}) + require.Nil(t, err) + + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + } + } + + _, _, _, cuPerEpochLimit, _, _, err := ts.keepers.Pairing.ValidatePairingForClient(sdk.UnwrapSDKContext(ts.ctx), ts.spec.Index, + consumer.Addr, ts.providers[0].Addr, uint64(sdk.UnwrapSDKContext(ts.ctx).BlockHeight())) + require.Nil(t, err) + + require.Equal(t, tt.effectiveCuPerEpochLimit, cuPerEpochLimit) + }) + } +} + +func TestPairingNotChangingDueToCuOveruse(t *testing.T) { + ts := setupForPaymentTest(t) + err := ts.addProvider(100) + require.Nil(t, err) + + // advance epoch to get pairing + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + + _, err = ts.servers.SubscriptionServer.Buy(ts.ctx, &subtypes.MsgBuy{ + Creator: ts.clients[0].Addr.String(), + Consumer: ts.clients[0].Addr.String(), + Index: ts.plan.Index, + Duration: 11, + Vrfpk: "", + }) + require.Nil(t, err) + + for i := 0; i < int(ts.plan.PlanPolicy.TotalCuLimit)/int(ts.plan.PlanPolicy.EpochCuLimit); i++ { + res, err := ts.keepers.Pairing.GetPairing(ts.ctx, &types.QueryGetPairingRequest{ + ChainID: ts.spec.Index, + Client: ts.clients[0].Addr.String(), + }) + require.Nil(t, err) + + cuSum := ts.plan.PlanPolicy.GetEpochCuLimit() + relayRequest := common.BuildRelayRequest(ts.ctx, res.Providers[0].Address, []byte(ts.spec.Apis[0].Name), cuSum, ts.spec.Name, nil) + relayRequest.SessionId = uint64(i) + relayRequest.Sig, err = sigs.SignRelay(ts.clients[0].SK, *relayRequest) + require.Nil(t, err) + _, err = ts.servers.PairingServer.RelayPayment(ts.ctx, &types.MsgRelayPayment{Creator: res.Providers[0].Address, Relays: []*types.RelaySession{relayRequest}}) + require.Nil(t, err) + + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + } + + res, err := ts.keepers.Pairing.GetPairing(ts.ctx, &types.QueryGetPairingRequest{ + ChainID: ts.spec.Index, + Client: ts.clients[0].Addr.String(), + }) + require.Nil(t, err) + firstPairing := res.Providers + + // advance an epoch block by block. On each one try to spend more than it's allowed and check the pairing hasn't changed + epochBlocks := ts.keepers.Epochstorage.EpochBlocksRaw(sdk.UnwrapSDKContext(ts.ctx)) + for i := 0; i < int(epochBlocks)-1; i++ { + ts.ctx = testkeeper.AdvanceBlock(ts.ctx, ts.keepers) + + res, err := ts.keepers.Pairing.GetPairing(ts.ctx, &types.QueryGetPairingRequest{ + ChainID: ts.spec.Index, + Client: ts.clients[0].Addr.String(), + }) + require.Nil(t, err) + + cuSum := ts.plan.PlanPolicy.GetEpochCuLimit() + relayRequest := common.BuildRelayRequest(ts.ctx, res.Providers[0].Address, []byte(ts.spec.Apis[0].Name), cuSum, ts.spec.Name, nil) + relayRequest.SessionId = uint64(i) + relayRequest.Sig, err = sigs.SignRelay(ts.clients[0].SK, *relayRequest) + require.Nil(t, err) + _, err = ts.servers.PairingServer.RelayPayment(ts.ctx, &types.MsgRelayPayment{Creator: res.Providers[0].Address, Relays: []*types.RelaySession{relayRequest}}) + require.NotNil(t, err) + + require.Equal(t, firstPairing, res.Providers) + } +} + +func TestAddProjectAfterPlanUpdate(t *testing.T) { + ts := setupForPaymentTest(t) + err := ts.addClient(1) + require.Nil(t, err) + + // advance epoch to get pairing + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + + _, err = ts.servers.SubscriptionServer.Buy(ts.ctx, &subtypes.MsgBuy{ + Creator: ts.clients[0].Addr.String(), + Consumer: ts.clients[0].Addr.String(), + Index: ts.plan.Index, + Duration: 11, + Vrfpk: "", + }) + require.Nil(t, err) + + sub, found := ts.keepers.Subscription.GetSubscription(sdk.UnwrapSDKContext(ts.ctx), ts.clients[0].Addr.String()) + require.True(t, found) + + // advance epoch so the plan edit will be on a different block than the subscription purchase + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + + // edit the plan the subscription purchased (allow less CU per epoch) + subPlan, found := ts.keepers.Plans.FindPlan(sdk.UnwrapSDKContext(ts.ctx), sub.PlanIndex, sub.PlanBlock) + require.True(t, found) + oldEpochCuLimit := subPlan.PlanPolicy.EpochCuLimit + subPlan.PlanPolicy.EpochCuLimit -= 50 + err = ts.keepers.Plans.AddPlan(sdk.UnwrapSDKContext(ts.ctx), subPlan) + require.Nil(t, err) + + // add another project under the subcscription + projectData := projectstypes.ProjectData{ + Name: "anotherProject", + Description: "dummyDesc", + Enabled: true, + ProjectKeys: []projectstypes.ProjectKey{ + { + Key: ts.clients[1].Addr.String(), + Types: []projectstypes.ProjectKey_KEY_TYPE{ + projectstypes.ProjectKey_DEVELOPER, + projectstypes.ProjectKey_ADMIN, + }, + Vrfpk: "", + }, + }, + Policy: nil, + } + err = ts.keepers.Subscription.AddProjectToSubscription(sdk.UnwrapSDKContext(ts.ctx), ts.clients[0].Addr.String(), projectData) + require.Nil(t, err) + + proj, _, err := ts.keepers.Projects.GetProjectForDeveloper(sdk.UnwrapSDKContext(ts.ctx), ts.clients[1].Addr.String(), + uint64(sdk.UnwrapSDKContext(ts.ctx).BlockHeight())) + require.Nil(t, err) + + // set a new policy to the second project, making it more strict than the old plan but less strict than the new plan + adminPolicy := ts.plan.PlanPolicy + adminPolicy.EpochCuLimit = oldEpochCuLimit - 30 + + err = ts.keepers.Projects.SetPolicy(sdk.UnwrapSDKContext(ts.ctx), []string{proj.Index}, &adminPolicy, + ts.clients[1].Addr.String(), projectstypes.SET_ADMIN_POLICY) + require.Nil(t, err) + + // advance epoch to set the new policy + ts.ctx = testkeeper.AdvanceEpoch(ts.ctx, ts.keepers) + + _, _, _, cuPerEpochLimit, _, _, err := ts.keepers.Pairing.ValidatePairingForClient(sdk.UnwrapSDKContext(ts.ctx), + ts.spec.Index, ts.clients[1].Addr, ts.providers[0].Addr, uint64(sdk.UnwrapSDKContext(ts.ctx).BlockHeight())) + require.Nil(t, err) + + // in terms of strictness: newPlan < adminPolicy < oldPlan but newPlan should not apply to the second project (since it's under a subscription that uses the old plan) + require.Equal(t, adminPolicy.EpochCuLimit, cuPerEpochLimit) +} diff --git a/x/projects/keeper/creation.go b/x/projects/keeper/creation.go index 9dda2f1733..0ef59b4ca1 100644 --- a/x/projects/keeper/creation.go +++ b/x/projects/keeper/creation.go @@ -82,6 +82,8 @@ func (k Keeper) RegisterKey(ctx sdk.Context, key types.ProjectKey, project *type return utils.LavaError(ctx, k.Logger(ctx), "RegisterKey_add_dev_key_failed", details, "adding developer key failed") } } + default: + panic("requested key has an invalid type") } } diff --git a/x/projects/keeper/msg_server_add_project_keys.go b/x/projects/keeper/msg_server_add_project_keys.go index d6fa526dc0..0a315c210c 100644 --- a/x/projects/keeper/msg_server_add_project_keys.go +++ b/x/projects/keeper/msg_server_add_project_keys.go @@ -2,6 +2,7 @@ package keeper import ( "context" + "fmt" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/lavanet/lava/x/projects/types" @@ -10,6 +11,14 @@ import ( func (k msgServer) AddProjectKeys(goCtx context.Context, msg *types.MsgAddProjectKeys) (*types.MsgAddProjectKeysResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) + for _, projectKey := range msg.GetProjectKeys() { + for _, keyType := range projectKey.GetTypes() { + if keyType != types.ProjectKey_ADMIN && keyType != types.ProjectKey_DEVELOPER { + return nil, fmt.Errorf("project key must be of type ADMIN(=1) or DEVELOPER(=2). projectKey = %d", keyType) + } + } + } + err := k.AddKeysToProject(ctx, msg.Project, msg.Creator, msg.ProjectKeys) if err != nil { return nil, err diff --git a/x/projects/keeper/msg_server_set_admin_policy.go b/x/projects/keeper/msg_server_set_admin_policy.go index c90b2d59b9..2853f025d1 100644 --- a/x/projects/keeper/msg_server_set_admin_policy.go +++ b/x/projects/keeper/msg_server_set_admin_policy.go @@ -11,7 +11,13 @@ func (k msgServer) SetAdminPolicy(goCtx context.Context, msg *types.MsgSetAdminP ctx := sdk.UnwrapSDKContext(goCtx) policy := msg.GetPolicy() - err := k.SetPolicy(ctx, []string{msg.GetProject()}, &policy, msg.GetCreator(), types.SET_ADMIN_POLICY) + + err := policy.ValidateBasicPolicy() + if err != nil { + return nil, err + } + + err = k.SetPolicy(ctx, []string{msg.GetProject()}, &policy, msg.GetCreator(), types.SET_ADMIN_POLICY) if err != nil { return nil, err } diff --git a/x/projects/keeper/msg_server_set_subscription_policy.go b/x/projects/keeper/msg_server_set_subscription_policy.go index a9c0ad3cfc..cc4efcfe7e 100644 --- a/x/projects/keeper/msg_server_set_subscription_policy.go +++ b/x/projects/keeper/msg_server_set_subscription_policy.go @@ -11,7 +11,13 @@ func (k msgServer) SetSubscriptionPolicy(goCtx context.Context, msg *types.MsgSe ctx := sdk.UnwrapSDKContext(goCtx) policy := msg.GetPolicy() - err := k.SetPolicy(ctx, msg.GetProjects(), &policy, msg.GetCreator(), types.SET_SUBSCRIPTION_POLICY) + + err := policy.ValidateBasicPolicy() + if err != nil { + return nil, err + } + + err = k.SetPolicy(ctx, msg.GetProjects(), &policy, msg.GetCreator(), types.SET_SUBSCRIPTION_POLICY) if err != nil { return nil, err } diff --git a/x/projects/keeper/project_test.go b/x/projects/keeper/project_test.go index 574e52090a..296247ef3e 100644 --- a/x/projects/keeper/project_test.go +++ b/x/projects/keeper/project_test.go @@ -3,15 +3,19 @@ package keeper_test import ( "context" "math" + "strings" "testing" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/lavanet/lava/testutil/common" testkeeper "github.com/lavanet/lava/testutil/keeper" "github.com/lavanet/lava/x/projects/types" + subscriptiontypes "github.com/lavanet/lava/x/subscription/types" "github.com/stretchr/testify/require" ) +const projectName = "mockname" + func prepareProjectsData(ctx context.Context, keepers *testkeeper.Keepers) (projects []types.ProjectData) { adm1Addr := common.CreateNewAccount(ctx, *keepers, 10000).Addr.String() adm2Addr := common.CreateNewAccount(ctx, *keepers, 10000).Addr.String() @@ -36,7 +40,10 @@ func prepareProjectsData(ctx context.Context, keepers *testkeeper.Keepers) (proj {Key: dev3Addr, Types: typeDevel, Vrfpk: ""}, } - policy1 := &types.Policy{GeolocationProfile: math.MaxUint64} + policy1 := &types.Policy{ + GeolocationProfile: math.MaxUint64, + MaxProvidersToPair: 2, + } templates := []struct { name string @@ -46,8 +53,8 @@ func prepareProjectsData(ctx context.Context, keepers *testkeeper.Keepers) (proj }{ // project with admin key, enabled, has policy {"mock_project_1", true, keys_1_admin, policy1}, - // project with "both" key, disabled, no policy - {"mock_project_2", false, keys_1_admin_dev, nil}, + // project with "both" key, disabled, with policy + {"mock_project_2", false, keys_1_admin_dev, policy1}, // project with 2 keys (one admin, one developer) disabled, no policy {"mock_project_3", false, keys_2_admin_and_dev, nil}, } @@ -88,7 +95,7 @@ func TestCreateDefaultProject(t *testing.T) { } func TestCreateProject(t *testing.T) { - _, keepers, _ctx := testkeeper.InitAllKeepers(t) + servers, keepers, _ctx := testkeeper.InitAllKeepers(t) ctx := sdk.UnwrapSDKContext(_ctx) projectData := prepareProjectsData(_ctx, keepers)[1] @@ -103,11 +110,70 @@ func TestCreateProject(t *testing.T) { _ctx = testkeeper.AdvanceEpoch(_ctx, keepers) ctx = sdk.UnwrapSDKContext(_ctx) - // create another project with the same name, should fail as this is unique - err = keepers.Projects.CreateProject(ctx, subAddr, projectData, plan) + // test invalid project name/description + defaultProjectName := types.ADMIN_PROJECT_NAME + longProjectName := strings.Repeat(defaultProjectName, types.MAX_PROJECT_NAME_LEN+1) + projectNameWithComma := "projectName," + nonAsciiProjectName := "projectName¢" + + projectDescription := "test project" + longProjectDescription := strings.Repeat(projectDescription, types.MAX_PROJECT_DESCRIPTION_LEN+1) + nonAsciiProjectDescription := "projectDesc¢" + + testProjectData := projectData + testProjectData.ProjectKeys = []types.ProjectKey{} + + nameAndDescriptionTests := []struct { + name string + projectName string + projectDescription string + }{ + {"bad projectName (duplicate)", projectName, projectDescription}, + {"bad projectName (too long)", longProjectName, projectDescription}, + {"bad projectName (contains comma)", projectNameWithComma, projectDescription}, + {"bad projectName (non ascii)", nonAsciiProjectName, projectDescription}, + {"bad projectName (empty)", "", projectDescription}, + {"bad projectDescription (too long)", "test1", longProjectDescription}, + {"bad projectDescription (non ascii)", "test2", nonAsciiProjectDescription}, + } + + for _, tt := range nameAndDescriptionTests { + t.Run(tt.name, func(t *testing.T) { + testProjectData.Name = tt.projectName + testProjectData.Description = tt.projectDescription + + err = keepers.Projects.CreateProject(ctx, subAddr, testProjectData, plan) + require.NotNil(t, err) + }) + } + // _ctx = sdk.WrapSDKContext(ctx) + + // continue testing traits that are not related to the project's name/description + // try creating a project with invalid project keys + invalidKeysProjectData := projectData + invalidKeysProjectData.Name = "nonDuplicateProjectName" + invalidKeysProjectData.ProjectKeys = []types.ProjectKey{ + { + Key: subAddr, + Types: []types.ProjectKey_KEY_TYPE{types.ProjectKey_DEVELOPER}, + Vrfpk: "", + }, + { + Key: admAddr, + Types: []types.ProjectKey_KEY_TYPE{4}, + Vrfpk: "", + }, + } + + // should fail because there's an invalid key + _, err = servers.SubscriptionServer.AddProject(_ctx, &subscriptiontypes.MsgAddProject{ + Creator: subAddr, + ProjectData: invalidKeysProjectData, + }) require.NotNil(t, err) - // subscription key is not a developer + // get project by developer - subscription key is not a developer, should fail (if it succeeds, it means that the valid project key + // from invalidKeysProjectData was registered, which is not desired!) _, err = keepers.Projects.Developer(_ctx, &types.QueryDeveloperRequest{Developer: subAddr}) require.NotNil(t, err) @@ -158,6 +224,11 @@ func TestAddKeys(t *testing.T) { _, err = servers.ProjectServer.AddProjectKeys(ctx, &types.MsgAddProjectKeys{Creator: dev1Addr, Project: project.Index, ProjectKeys: []types.ProjectKey{pk}}) require.NotNil(t, err) + // admin key adding an invalid key + pk = types.ProjectKey{Key: dev2Addr, Types: []types.ProjectKey_KEY_TYPE{4}} + _, err = servers.ProjectServer.AddProjectKeys(ctx, &types.MsgAddProjectKeys{Creator: admAddr, Project: project.Index, ProjectKeys: []types.ProjectKey{pk}}) + require.NotNil(t, err) + // admin key adding a developer pk = types.ProjectKey{Key: dev2Addr, Types: []types.ProjectKey_KEY_TYPE{types.ProjectKey_DEVELOPER}} _, err = servers.ProjectServer.AddProjectKeys(ctx, &types.MsgAddProjectKeys{Creator: admAddr, Project: project.Index, ProjectKeys: []types.ProjectKey{pk}}) @@ -235,55 +306,80 @@ func SetPolicyTest(t *testing.T, testAdminPolicy bool) { require.Nil(t, err) pk := types.ProjectKey{Key: devAddr, Types: []types.ProjectKey_KEY_TYPE{types.ProjectKey_DEVELOPER}} - keepers.Projects.AddKeysToProject(ctx, projectID, admAddr, []types.ProjectKey{pk}) + err = keepers.Projects.AddKeysToProject(ctx, projectID, admAddr, []types.ProjectKey{pk}) + require.Nil(t, err) spec := common.CreateMockSpec() keepers.Spec.SetSpec(ctx, spec) templates := []struct { - name string - creator string - chainPolicies []types.ChainPolicy - totalCuLimit uint64 - epochCuLimit uint64 - maxProvidersToPair uint64 - validateBasicSuccess bool - setPolicySuccess bool + name string + creator string + projectID string + geolocation uint64 + chainPolicies []types.ChainPolicy + totalCuLimit uint64 + epochCuLimit uint64 + maxProvidersToPair uint64 + setAdminPolicySuccess bool + setSubscriptionPolicySuccess bool }{ { - "valid policy (admin account)", admAddr, + "valid policy (admin account)", admAddr, projectID, uint64(1), []types.ChainPolicy{{ChainId: spec.Index, Apis: []string{spec.Apis[0].Name}}}, - 100, 10, 3, true, true, + 100, 10, 3, true, false, }, + { - "valid policy (subscription account)", subAddr, + "valid policy (subscription account)", subAddr, projectID, uint64(1), []types.ChainPolicy{{ChainId: spec.Index, Apis: []string{spec.Apis[0].Name}}}, 100, 10, 3, true, true, }, + + { + "bad creator (developer account -- not admin)", devAddr, projectID, uint64(1), + []types.ChainPolicy{{ChainId: spec.Index, Apis: []string{spec.Apis[0].Name}}}, + 100, 10, 3, false, false, + }, + { - "bad creator (developer account -- not admin)", devAddr, + "bad projectID (doesn't exist)", devAddr, "fakeProjectId", uint64(1), []types.ChainPolicy{{ChainId: spec.Index, Apis: []string{spec.Apis[0].Name}}}, - 100, 10, 3, true, false, + 100, 10, 3, false, false, + }, + + { + "invalid geolocation (0)", devAddr, projectID, uint64(0), + []types.ChainPolicy{{ChainId: spec.Index, Apis: []string{spec.Apis[0].Name}}}, + 100, 10, 3, false, false, }, + { // note: currently, we don't verify the chain policies - "bad chainID (doesn't exist)", subAddr, + "bad chainID (doesn't exist)", subAddr, projectID, uint64(1), []types.ChainPolicy{{ChainId: "LOL", Apis: []string{spec.Apis[0].Name}}}, 100, 10, 3, true, true, }, + { // note: currently, we don't verify the chain policies - "bad API (doesn't exist)", subAddr, + "bad API (doesn't exist)", subAddr, projectID, uint64(1), []types.ChainPolicy{{ChainId: spec.Index, Apis: []string{"lol"}}}, 100, 10, 3, true, true, }, { - "epoch CU larger than total CU", subAddr, + // note: currently, we don't verify the chain policies + "chainID and API not supported (exist in Lava's specs)", subAddr, projectID, uint64(1), + []types.ChainPolicy{{ChainId: "ETH1", Apis: []string{"eth_accounts"}}}, + 100, 10, 3, true, true, + }, + { + "epoch CU larger than total CU", subAddr, projectID, uint64(1), []types.ChainPolicy{{ChainId: spec.Index, Apis: []string{spec.Apis[0].Name}}}, 10, 100, 3, false, false, }, { - "bad maxProvidersToPair", subAddr, + "bad maxProvidersToPair", subAddr, projectID, uint64(1), []types.ChainPolicy{{ChainId: spec.Index, Apis: []string{spec.Apis[0].Name}}}, 100, 10, 1, false, false, }, @@ -293,7 +389,7 @@ func SetPolicyTest(t *testing.T, testAdminPolicy bool) { t.Run(tt.name, func(t *testing.T) { newPolicy := types.Policy{ ChainPolicies: tt.chainPolicies, - GeolocationProfile: uint64(1), + GeolocationProfile: tt.geolocation, TotalCuLimit: tt.totalCuLimit, EpochCuLimit: tt.epochCuLimit, MaxProvidersToPair: tt.maxProvidersToPair, @@ -303,24 +399,19 @@ func SetPolicyTest(t *testing.T, testAdminPolicy bool) { setAdminPolicyMessage := types.MsgSetAdminPolicy{ Creator: tt.creator, Policy: newPolicy, - Project: projectID, + Project: tt.projectID, } err = setAdminPolicyMessage.ValidateBasic() - if tt.validateBasicSuccess { - require.Nil(t, err) - } else { - require.NotNil(t, err) - return - } + require.Nil(t, err) _, err := servers.ProjectServer.SetAdminPolicy(_ctx, &setAdminPolicyMessage) - if tt.setPolicySuccess { + if tt.setAdminPolicySuccess { require.Nil(t, err) _ctx = testkeeper.AdvanceEpoch(_ctx, keepers) - ctx := sdk.UnwrapSDKContext(_ctx) + ctx = sdk.UnwrapSDKContext(_ctx) - proj, err := keepers.Projects.GetProjectForBlock(ctx, projectID, uint64(ctx.BlockHeight())) + proj, err := keepers.Projects.GetProjectForBlock(ctx, tt.projectID, uint64(ctx.BlockHeight())) require.Nil(t, err) require.Equal(t, newPolicy, *proj.AdminPolicy) @@ -331,35 +422,22 @@ func SetPolicyTest(t *testing.T, testAdminPolicy bool) { setSubscriptionPolicyMessage := types.MsgSetSubscriptionPolicy{ Creator: tt.creator, Policy: newPolicy, - Projects: []string{projectID}, + Projects: []string{tt.projectID}, } err = setSubscriptionPolicyMessage.ValidateBasic() - if tt.validateBasicSuccess { - require.Nil(t, err) - } else { - require.NotNil(t, err) - return - } + require.Nil(t, err) _, err := servers.ProjectServer.SetSubscriptionPolicy(_ctx, &setSubscriptionPolicyMessage) - if tt.creator == subAddr { - // only the subscription consumer should be able to set subscription policy + if tt.setSubscriptionPolicySuccess { require.Nil(t, err) + _ctx = testkeeper.AdvanceEpoch(_ctx, keepers) + ctx = sdk.UnwrapSDKContext(_ctx) - if tt.setPolicySuccess { - require.Nil(t, err) - - _ctx = testkeeper.AdvanceEpoch(_ctx, keepers) - ctx := sdk.UnwrapSDKContext(_ctx) - - proj, err := keepers.Projects.GetProjectForBlock(ctx, projectID, uint64(ctx.BlockHeight())) - require.Nil(t, err) + proj, err := keepers.Projects.GetProjectForBlock(ctx, tt.projectID, uint64(ctx.BlockHeight())) + require.Nil(t, err) - require.Equal(t, newPolicy, *proj.SubscriptionPolicy) - } else { - require.NotNil(t, err) - } + require.Equal(t, newPolicy, *proj.SubscriptionPolicy) } else { require.NotNil(t, err) } @@ -435,3 +513,126 @@ func TestChargeComputeUnits(t *testing.T) { require.Nil(t, err) require.Equal(t, uint64(0), proj.UsedCu) } + +func TestAddDevKeyToSameProjectDifferentBlocks(t *testing.T) { + _, keepers, ctx := testkeeper.InitAllKeepers(t) + + projectName := "mockname1" + subAddr := common.CreateNewAccount(ctx, *keepers, 10000).Addr.String() + dev1Addr := common.CreateNewAccount(ctx, *keepers, 10000).Addr.String() + dev2Addr := common.CreateNewAccount(ctx, *keepers, 10000).Addr.String() + projectID := types.ProjectIndex(subAddr, projectName) + plan := common.CreateMockPlan() + + projectData := types.ProjectData{ + Name: projectName, + Description: "", + Enabled: true, + ProjectKeys: []types.ProjectKey{{ + Key: subAddr, + Types: []types.ProjectKey_KEY_TYPE{types.ProjectKey_DEVELOPER}, + Vrfpk: "", + }}, + Policy: &plan.PlanPolicy, + } + err := keepers.Projects.CreateProject(sdk.UnwrapSDKContext(ctx), subAddr, projectData, plan) + require.Nil(t, err) + + ctx = testkeeper.AdvanceBlock(ctx, keepers) + + err = keepers.Projects.AddKeysToProject(sdk.UnwrapSDKContext(ctx), projectID, subAddr, + []types.ProjectKey{{ + Key: dev1Addr, + Types: []types.ProjectKey_KEY_TYPE{types.ProjectKey_DEVELOPER}, + Vrfpk: "", + }}) + require.Nil(t, err) + + ctx = testkeeper.AdvanceBlock(ctx, keepers) + + err = keepers.Projects.AddKeysToProject(sdk.UnwrapSDKContext(ctx), projectID, subAddr, + []types.ProjectKey{{ + Key: dev2Addr, + Types: []types.ProjectKey_KEY_TYPE{types.ProjectKey_DEVELOPER}, + Vrfpk: "", + }}) + require.Nil(t, err) + + proj, _, err := keepers.Projects.GetProjectForDeveloper(sdk.UnwrapSDKContext(ctx), subAddr, + uint64(sdk.UnwrapSDKContext(ctx).BlockHeight())) + require.Nil(t, err) + + require.Equal(t, 3, len(proj.ProjectKeys)) +} + +func TestAddDevKeyToDifferentProjectsInSameBlock(t *testing.T) { + _, keepers, ctx := testkeeper.InitAllKeepers(t) + plan := common.CreateMockPlan() + + sub1Addr := common.CreateNewAccount(ctx, *keepers, 10000).Addr.String() + sub2Addr := common.CreateNewAccount(ctx, *keepers, 10000).Addr.String() + devAddr := common.CreateNewAccount(ctx, *keepers, 10000).Addr.String() + + projectName1 := "mockname1" + projectName2 := "mockname2" + + projectID1 := types.ProjectIndex(sub1Addr, projectName1) + projectID2 := types.ProjectIndex(sub2Addr, projectName2) + + projectData1 := types.ProjectData{ + Name: projectName1, + Description: "", + Enabled: true, + ProjectKeys: []types.ProjectKey{{ + Key: sub1Addr, + Types: []types.ProjectKey_KEY_TYPE{types.ProjectKey_DEVELOPER}, + Vrfpk: "", + }}, + Policy: &plan.PlanPolicy, + } + err := keepers.Projects.CreateProject(sdk.UnwrapSDKContext(ctx), sub1Addr, projectData1, plan) + require.Nil(t, err) + + projectData2 := types.ProjectData{ + Name: projectName2, + Description: "", + Enabled: true, + ProjectKeys: []types.ProjectKey{{ + Key: sub2Addr, + Types: []types.ProjectKey_KEY_TYPE{types.ProjectKey_DEVELOPER}, + Vrfpk: "", + }}, + Policy: &plan.PlanPolicy, + } + err = keepers.Projects.CreateProject(sdk.UnwrapSDKContext(ctx), sub2Addr, projectData2, plan) + require.Nil(t, err) + + ctx = testkeeper.AdvanceBlock(ctx, keepers) + + err = keepers.Projects.AddKeysToProject(sdk.UnwrapSDKContext(ctx), projectID1, sub1Addr, + []types.ProjectKey{{ + Key: devAddr, + Types: []types.ProjectKey_KEY_TYPE{types.ProjectKey_DEVELOPER}, + Vrfpk: "", + }}) + require.Nil(t, err) + + err = keepers.Projects.AddKeysToProject(sdk.UnwrapSDKContext(ctx), projectID2, sub2Addr, + []types.ProjectKey{{ + Key: devAddr, + Types: []types.ProjectKey_KEY_TYPE{types.ProjectKey_DEVELOPER}, + Vrfpk: "", + }}) + require.NotNil(t, err) // should fail since this developer was already added to the first project + + proj1, _, err := keepers.Projects.GetProjectForDeveloper(sdk.UnwrapSDKContext(ctx), sub1Addr, + uint64(sdk.UnwrapSDKContext(ctx).BlockHeight())) + require.Nil(t, err) + + proj2, _, err := keepers.Projects.GetProjectForDeveloper(sdk.UnwrapSDKContext(ctx), sub2Addr, + uint64(sdk.UnwrapSDKContext(ctx).BlockHeight())) + require.Nil(t, err) + + require.Equal(t, 2, len(proj1.ProjectKeys)) + require.Equal(t, 1, len(proj2.ProjectKeys)) +} diff --git a/x/projects/types/message_add_project_keys.go b/x/projects/types/message_add_project_keys.go index 529d4fb237..c22ec7f212 100644 --- a/x/projects/types/message_add_project_keys.go +++ b/x/projects/types/message_add_project_keys.go @@ -44,12 +44,5 @@ func (msg *MsgAddProjectKeys) ValidateBasic() error { return sdkerrors.Wrapf(sdkerrors.ErrInvalidAddress, "invalid creator address (%s)", err) } - for _, projectKey := range msg.GetProjectKeys() { - for _, keyType := range projectKey.GetTypes() { - if keyType != ProjectKey_ADMIN && keyType != ProjectKey_DEVELOPER { - return sdkerrors.Wrapf(ErrInvalidKeyType, "project key must be of type ADMIN(=1) or DEVELOPER(=2). projectKey = %d", keyType) - } - } - } return nil } diff --git a/x/projects/types/message_set_admin_policy.go b/x/projects/types/message_set_admin_policy.go index 7fc8a92eb7..3463e4c566 100644 --- a/x/projects/types/message_set_admin_policy.go +++ b/x/projects/types/message_set_admin_policy.go @@ -44,9 +44,5 @@ func (msg *MsgSetAdminPolicy) ValidateBasic() error { return sdkerrors.Wrapf(sdkerrors.ErrInvalidAddress, "invalid creator address (%s)", err) } - err = msg.GetPolicy().ValidateBasicPolicy() - if err != nil { - return sdkerrors.Wrapf(ErrInvalidPolicy, "invalid policy") - } return nil } diff --git a/x/projects/types/message_set_subscription_policy.go b/x/projects/types/message_set_subscription_policy.go index b3e8fe3056..26986f0773 100644 --- a/x/projects/types/message_set_subscription_policy.go +++ b/x/projects/types/message_set_subscription_policy.go @@ -44,9 +44,5 @@ func (msg *MsgSetSubscriptionPolicy) ValidateBasic() error { return sdkerrors.Wrapf(sdkerrors.ErrInvalidAddress, "invalid creator address (%s)", err) } - err = msg.GetPolicy().ValidateBasicPolicy() - if err != nil { - return sdkerrors.Wrapf(ErrInvalidPolicy, "invalid policy") - } return nil } diff --git a/x/projects/types/project.go b/x/projects/types/project.go index cd3b108296..e4faeccb44 100644 --- a/x/projects/types/project.go +++ b/x/projects/types/project.go @@ -38,7 +38,7 @@ func NewProject(subscriptionAddress string, projectName string, description stri func ValidateProjectNameAndDescription(name string, description string) bool { if strings.Contains(name, ",") || !commontypes.IsASCII(name) || len(name) > MAX_PROJECT_NAME_LEN || len(description) > MAX_PROJECT_DESCRIPTION_LEN || - name == "" { + name == "" || !commontypes.IsASCII(description) { return false } diff --git a/x/subscription/keeper/msg_server_add_project.go b/x/subscription/keeper/msg_server_add_project.go index fc5fee3bc8..dd704693a7 100644 --- a/x/subscription/keeper/msg_server_add_project.go +++ b/x/subscription/keeper/msg_server_add_project.go @@ -2,15 +2,53 @@ package keeper import ( "context" + "strconv" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/lavanet/lava/utils" + projectstypes "github.com/lavanet/lava/x/projects/types" "github.com/lavanet/lava/x/subscription/types" ) func (k msgServer) AddProject(goCtx context.Context, msg *types.MsgAddProject) (*types.MsgAddProjectResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) + for _, projectKey := range msg.GetProjectData().ProjectKeys { + _, err := sdk.AccAddressFromBech32(projectKey.GetKey()) + if err != nil { + details := map[string]string{ + "key": projectKey.Key, + "err": err.Error(), + } + return nil, utils.LavaError(ctx, k.Logger(ctx), "AddProject_invalid_project_key", details, "invalid project key") + } + + for _, projectKeyType := range projectKey.Types { + if projectKeyType != projectstypes.ProjectKey_ADMIN && projectKeyType != projectstypes.ProjectKey_DEVELOPER { + details := map[string]string{ + "key": projectKey.Key, + "type": strconv.FormatInt(int64(projectKeyType), 10), + } + return nil, utils.LavaError(ctx, k.Logger(ctx), "AddProject_invalid_project_key_type", details, "invalid project key type (should be 1 or 2)") + } + } + + if !projectstypes.ValidateProjectNameAndDescription(msg.GetProjectData().Name, msg.GetProjectData().Description) { + details := map[string]string{ + "name": msg.GetProjectData().Name, + "description": msg.GetProjectData().Description, + } + return nil, utils.LavaError(ctx, k.Logger(ctx), "AddProject_invalid_project_name_or_description", details, "invalid project name or description (might be too long or include disallowed characters)") + } + + if msg.GetProjectData().Policy.MaxProvidersToPair <= 1 { + details := map[string]string{ + "maxProvidersToPair": strconv.FormatUint(msg.GetProjectData().Policy.MaxProvidersToPair, 10), + } + return nil, utils.LavaError(ctx, k.Logger(ctx), "AddProject_invalid_project_providers_to_pair", details, "invalid project providersToPair (must be larger than one)") + } + } + err := k.Keeper.AddProjectToSubscription(ctx, msg.GetCreator(), msg.GetProjectData()) if err == nil { logger := k.Keeper.Logger(ctx) diff --git a/x/subscription/keeper/subscription_test.go b/x/subscription/keeper/subscription_test.go index 161eb808ce..eb8334a1fc 100644 --- a/x/subscription/keeper/subscription_test.go +++ b/x/subscription/keeper/subscription_test.go @@ -292,15 +292,35 @@ func TestRenewSubscription(t *testing.T) { err = keeper.CreateSubscription(ts.ctx, creator, creator, ts.plans[0].Index, 12, "") require.NotNil(t, err) - // but asking for additional 10 is fine - err = keeper.CreateSubscription(ts.ctx, creator, creator, ts.plans[0].Index, 10, "") + // but asking for additional 9 months (10 would also be fine (the extra month extension below)) + err = keeper.CreateSubscription(ts.ctx, creator, creator, ts.plans[0].Index, 9, "") require.Nil(t, err) sub, found = keeper.GetSubscription(ts.ctx, creator) require.True(t, found) - require.Equal(t, uint64(13), sub.DurationLeft) - require.Equal(t, uint64(10), sub.DurationTotal) + require.Equal(t, uint64(12), sub.DurationLeft) + require.Equal(t, uint64(9), sub.DurationTotal) + + // edit the subscription's plan (allow more CU) + subPlan, found := ts.keepers.Plans.FindPlan(ts.ctx, sub.PlanIndex, sub.PlanBlock) + require.True(t, found) + oldPlanCuPerEpoch := subPlan.PlanPolicy.EpochCuLimit + subPlan.PlanPolicy.EpochCuLimit += 100 + err = keepertest.SimulatePlansProposal(ts.ctx, ts.keepers.Plans, []planstypes.Plan{subPlan}) + require.Nil(t, err) + + // try extending the subscription (normally we could extend with 1 more month, but since the + // subscription's plan changed, the extension should fail) + err = keeper.CreateSubscription(ts.ctx, creator, creator, ts.plans[0].Index, 1, "") + require.NotNil(t, err) + require.Equal(t, uint64(12), sub.DurationLeft) + require.Equal(t, uint64(9), sub.DurationTotal) + + // get the subscription's plan and make sure it uses the old plan + subPlan, found = ts.keepers.Plans.FindPlan(ts.ctx, sub.PlanIndex, sub.PlanBlock) + require.True(t, found) + require.Equal(t, oldPlanCuPerEpoch, subPlan.PlanPolicy.EpochCuLimit) } func TestSubscriptionAdminProject(t *testing.T) { @@ -327,63 +347,107 @@ func TestMonthlyRechargeCU(t *testing.T) { projectKeeper := ts.keepers.Projects account := common.CreateNewAccount(ts._ctx, *ts.keepers, 10000) + anotherAccount := common.CreateNewAccount(ts._ctx, *ts.keepers, 10000) creator := account.Addr.String() - err := keeper.CreateSubscription(ts.ctx, creator, creator, ts.plans[0].Index, 2, "") + err := keeper.CreateSubscription(ts.ctx, creator, creator, ts.plans[0].Index, 3, "") require.Nil(t, err) - block1 := uint64(ts.ctx.BlockHeight()) + // add another project under the subcscription + projectData := projectstypes.ProjectData{ + Name: "anotherProject", + Description: "dummyDesc", + Enabled: true, + ProjectKeys: []projectstypes.ProjectKey{ + { + Key: anotherAccount.Addr.String(), + Types: []projectstypes.ProjectKey_KEY_TYPE{projectstypes.ProjectKey_DEVELOPER}, + Vrfpk: "", + }, + }, + Policy: &projectstypes.Policy{ + GeolocationProfile: uint64(1), + TotalCuLimit: 1000, + EpochCuLimit: 100, + MaxProvidersToPair: 3, + }, + } + err = keeper.AddProjectToSubscription(ts.ctx, creator, projectData) + require.Nil(t, err) - sub, found := keeper.GetSubscription(ts.ctx, creator) - require.True(t, found) + // we'll test both the default project and the second project, which differ in their developers + template := []struct { + name string + subscription string + developer string + usedCuPerProject uint64 // total CU in sub is 1000 -- let each project use 500 + }{ + {"default project", creator, creator, 500}, + {"second project (non-default)", creator, anotherAccount.Addr.String(), 500}, + } + for ti, tt := range template { + t.Run(tt.name, func(t *testing.T) { + block1 := uint64(ts.ctx.BlockHeight()) - ts._ctx = keepertest.AdvanceEpoch(ts._ctx, ts.keepers) - ts.ctx = sdk.UnwrapSDKContext(ts._ctx) + ts._ctx = keepertest.AdvanceEpoch(ts._ctx, ts.keepers) + ts.ctx = sdk.UnwrapSDKContext(ts._ctx) - // use the subscription and the project - keeper.ChargeComputeUnitsToSubscription(ts.ctx, creator, 1000) - require.Equal(t, sub.PrevCuLeft, sub.MonthCuTotal-1000) - proj, _, err := projectKeeper.GetProjectForDeveloper(ts.ctx, creator, block1) - require.Nil(t, err) - err = projectKeeper.ChargeComputeUnitsToProject(ts.ctx, proj, block1, 1000) - require.Nil(t, err) + // charge the subscription + err = keeper.ChargeComputeUnitsToSubscription(ts.ctx, tt.subscription, tt.usedCuPerProject) + require.Nil(t, err) - // verify that project used the CU - proj, _, err = projectKeeper.GetProjectForDeveloper(ts.ctx, creator, block1) - require.Nil(t, err) - require.Equal(t, uint64(1000), proj.UsedCu) + sub, found := keeper.GetSubscription(ts.ctx, tt.subscription) + require.True(t, found) - block2 := uint64(ts.ctx.BlockHeight()) + // verify the CU charge of the subscription is updated correctly + // it depends on the iteration index since the same subscription is charged for all projects + require.Equal(t, sub.MonthCuLeft, sub.MonthCuTotal-tt.usedCuPerProject) + proj, _, err := projectKeeper.GetProjectForDeveloper(ts.ctx, tt.developer, block1) + require.Nil(t, err) - // force fixation entry (by adding project key) - projKey := []projectstypes.ProjectKey{ - { - Key: common.CreateNewAccount(ts._ctx, *ts.keepers, 10000).Addr.String(), - Types: []projectstypes.ProjectKey_KEY_TYPE{projectstypes.ProjectKey_ADMIN}, - }, - } - projectKeeper.AddKeysToProject(ts.ctx, projectstypes.ADMIN_PROJECT_NAME, creator, projKey) + // charge the project + err = projectKeeper.ChargeComputeUnitsToProject(ts.ctx, proj, block1, tt.usedCuPerProject) + require.Nil(t, err) - // fast-forward one months - sub = ts.expireSubscription(sub) - require.Equal(t, uint64(1), sub.DurationLeft) + // verify that project used the CU + proj, _, err = projectKeeper.GetProjectForDeveloper(ts.ctx, tt.developer, block1) + require.Nil(t, err) + require.Equal(t, tt.usedCuPerProject, proj.UsedCu) - block3 := uint64(ts.ctx.BlockHeight()) + block2 := uint64(ts.ctx.BlockHeight()) - // check that subscription and project have renewed CUs, and that the - // project created a snapshot for last month - sub, found = keeper.GetSubscription(ts.ctx, creator) - require.True(t, found) - require.Equal(t, sub.MonthCuLeft, sub.MonthCuTotal) - proj, _, err = projectKeeper.GetProjectForDeveloper(ts.ctx, creator, block1) - require.Nil(t, err) - require.Equal(t, uint64(1000), proj.UsedCu) - proj, _, err = projectKeeper.GetProjectForDeveloper(ts.ctx, creator, block2) - require.Nil(t, err) - require.Equal(t, uint64(1000), proj.UsedCu) - proj, _, err = projectKeeper.GetProjectForDeveloper(ts.ctx, creator, block3) - require.Nil(t, err) - require.Equal(t, uint64(0), proj.UsedCu) + // force fixation entry (by adding project key) + projKey := []projectstypes.ProjectKey{ + { + Key: common.CreateNewAccount(ts._ctx, *ts.keepers, 10000).Addr.String(), + Types: []projectstypes.ProjectKey_KEY_TYPE{projectstypes.ProjectKey_ADMIN}, + }, + } + projectKeeper.AddKeysToProject(ts.ctx, projectstypes.ADMIN_PROJECT_NAME, tt.developer, projKey) + + // fast-forward one month (since we expire the subscription in every iteration, it depends on the iteration number) + sub = ts.expireSubscription(sub) + require.Equal(t, sub.DurationTotal-uint64(ti+1), sub.DurationLeft) + + block3 := uint64(ts.ctx.BlockHeight()) + + // check that subscription and project have renewed CUs, and that the + // project created a snapshot for last month + sub, found = keeper.GetSubscription(ts.ctx, tt.subscription) + require.True(t, found) + require.Equal(t, sub.MonthCuLeft, sub.MonthCuTotal) + + proj, _, err = projectKeeper.GetProjectForDeveloper(ts.ctx, tt.developer, block1) + require.Nil(t, err) + require.Equal(t, tt.usedCuPerProject, proj.UsedCu) + proj, _, err = projectKeeper.GetProjectForDeveloper(ts.ctx, tt.developer, block2) + require.Nil(t, err) + require.Equal(t, tt.usedCuPerProject, proj.UsedCu) + proj, _, err = projectKeeper.GetProjectForDeveloper(ts.ctx, tt.developer, block3) + require.Nil(t, err) + require.Equal(t, uint64(0), proj.UsedCu) + }) + } } func TestExpiryTime(t *testing.T) { @@ -518,9 +582,12 @@ func TestAddProjectToSubscription(t *testing.T) { defaultProjectName := projectstypes.ADMIN_PROJECT_NAME longProjectName := strings.Repeat(defaultProjectName, projectstypes.MAX_PROJECT_NAME_LEN) + projectNameWithComma := "projectName," + nonAsciiProjectName := "projectName¢" projectDescription := "test project" longProjectDescription := strings.Repeat(projectDescription, projectstypes.MAX_PROJECT_DESCRIPTION_LEN) + nonAsciiProjectDescription := "projectDesc¢" template := []struct { name string @@ -536,7 +603,11 @@ func TestAddProjectToSubscription(t *testing.T) { {"bad subscription account (subscription payer account)", subPayerAddr, consumerAddr, "test5", projectDescription, false}, {"bad projectName (duplicate)", consumerAddr, regularAccountAddr, defaultProjectName, projectDescription, false}, {"bad projectName (too long)", consumerAddr, regularAccountAddr, longProjectName, projectDescription, false}, + {"bad projectName (contains comma)", consumerAddr, regularAccountAddr, projectNameWithComma, projectDescription, false}, + {"bad projectName (non ascii)", consumerAddr, regularAccountAddr, nonAsciiProjectName, projectDescription, false}, + {"bad projectName (empty)", consumerAddr, regularAccountAddr, "", projectDescription, false}, {"bad projectDescription (too long)", consumerAddr, regularAccountAddr, "test6", longProjectDescription, false}, + {"bad projectDescription (non ascii)", consumerAddr, regularAccountAddr, "test7", nonAsciiProjectDescription, false}, } for _, tt := range template { diff --git a/x/subscription/types/message_add_project.go b/x/subscription/types/message_add_project.go index ee18f1f706..3d6dd6582a 100644 --- a/x/subscription/types/message_add_project.go +++ b/x/subscription/types/message_add_project.go @@ -44,20 +44,5 @@ func (msg *MsgAddProject) ValidateBasic() error { return sdkerrors.Wrapf(sdkerrors.ErrInvalidAddress, "invalid creator address (%s)", err) } - for _, projectKey := range msg.GetProjectData().ProjectKeys { - _, err = sdk.AccAddressFromBech32(projectKey.GetKey()) - if err != nil { - return sdkerrors.Wrapf(sdkerrors.ErrInvalidAddress, "invalid address in project key (%s)", err) - } - } - - if !projectstypes.ValidateProjectNameAndDescription(msg.GetProjectData().Name, msg.GetProjectData().Description) { - return sdkerrors.Wrapf(ErrInvalidParameter, "invalid project name/description (name: %s, description: %s). Either name empty, name contains \",\", or name/description long (name_max_len = %d, description_max_len = %d)", msg.GetProjectData().Name, msg.GetProjectData().Description, projectstypes.MAX_PROJECT_NAME_LEN, projectstypes.MAX_PROJECT_DESCRIPTION_LEN) - } - - if msg.GetProjectData().Policy.MaxProvidersToPair <= 1 { - return sdkerrors.Wrapf(ErrInvalidParameter, "project maxProvidersToPair field is invalid (maxProvidersToPair = %v). This field must be greater than 1", msg.GetProjectData().Policy.MaxProvidersToPair) - } - return nil }