Skip to content

Commit

Permalink
feat: added option to configure static providers (#1629)
Browse files Browse the repository at this point in the history
* added option to configure static providers

* who doesnt like some lint on comments?

* disabled verifications for static provider on consumer, added static provider on provider side, disabled provider sessions on static provider code

* added unitests for static providers

* fix lock hanging

* added tests

* lint

* added examples prints and script to run static provider
  • Loading branch information
omerlavanet authored Aug 19, 2024
1 parent d1ce606 commit bc2e085
Show file tree
Hide file tree
Showing 17 changed files with 743 additions and 53 deletions.
23 changes: 23 additions & 0 deletions config/consumer_examples/lava_consumer_static_peers.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
endpoints:
- chain-id: LAV1
api-interface: rest
network-address: 127.0.0.1:3360
- chain-id: LAV1
api-interface: tendermintrpc
network-address: 127.0.0.1:3361
- chain-id: LAV1
api-interface: grpc
network-address: 127.0.0.1:3362
static-providers:
- api-interface: tendermintrpc
chain-id: LAV1
node-urls:
- url: 127.0.0.1:2220
- api-interface: grpc
chain-id: LAV1
node-urls:
- url: 127.0.0.1:2220
- api-interface: rest
chain-id: LAV1
node-urls:
- url: 127.0.0.1:2220
1 change: 1 addition & 0 deletions protocol/common/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type Test_mode_ctx_key struct{}
const (
PlainTextConnection = "allow-plaintext-connection"
EndpointsConfigName = "endpoints"
StaticProvidersConfigName = "static-providers"
SaveConfigFlagName = "save-conf"
GeolocationFlag = "geolocation"
TestModeFlagName = "test-mode"
Expand Down
23 changes: 22 additions & 1 deletion protocol/integration/mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package integration_test
import (
"context"
"fmt"
"net"
"net/http"
"strconv"
"sync"
Expand Down Expand Up @@ -31,7 +32,7 @@ type mockConsumerStateTracker struct {
func (m *mockConsumerStateTracker) RegisterForVersionUpdates(ctx context.Context, version *protocoltypes.Version, versionValidator updaters.VersionValidationInf) {
}

func (m *mockConsumerStateTracker) RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager) {
func (m *mockConsumerStateTracker) RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager, staticProviders []*lavasession.RPCProviderEndpoint) {
}

func (m *mockConsumerStateTracker) RegisterForSpecUpdates(ctx context.Context, specUpdatable updaters.SpecUpdatable, endpoint lavasession.RPCEndpoint) error {
Expand Down Expand Up @@ -267,6 +268,19 @@ type uniqueAddressGenerator struct {
lock sync.Mutex
}

func isPortInUse(port int) bool {
// Attempt to listen on the port
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
// If there's an error, the port is likely in use
return true
}

// Close the listener immediately if successful
ln.Close()
return false
}

func NewUniqueAddressGenerator() uniqueAddressGenerator {
return uniqueAddressGenerator{
currentPort: minPort,
Expand All @@ -277,6 +291,13 @@ func (ag *uniqueAddressGenerator) GetAddress() string {
ag.lock.Lock()
defer ag.lock.Unlock()

for {
if !isPortInUse(ag.currentPort) {
break
}
ag.currentPort++
}

if ag.currentPort > maxPort {
panic("all ports have been exhausted")
}
Expand Down
75 changes: 74 additions & 1 deletion protocol/integration/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ func createRpcProvider(t *testing.T, ctx context.Context, consumerAddress string
require.NoError(t, err)
reliabilityManager := reliabilitymanager.NewReliabilityManager(chainTracker, &mockProviderStateTracker, account.Addr.String(), chainRouter, chainParser)
mockReliabilityManager := NewMockReliabilityManager(reliabilityManager)
rpcProviderServer.ServeRPCRequests(ctx, rpcProviderEndpoint, chainParser, rws, providerSessionManager, mockReliabilityManager, account.SK, nil, chainRouter, &mockProviderStateTracker, account.Addr, lavaChainID, rpcprovider.DEFAULT_ALLOWED_MISSING_CU, nil, nil, nil)
rpcProviderServer.ServeRPCRequests(ctx, rpcProviderEndpoint, chainParser, rws, providerSessionManager, mockReliabilityManager, account.SK, nil, chainRouter, &mockProviderStateTracker, account.Addr, lavaChainID, rpcprovider.DEFAULT_ALLOWED_MISSING_CU, nil, nil, nil, false)
listener := rpcprovider.NewProviderListener(ctx, rpcProviderEndpoint.NetworkAddress, "/health")
err = listener.RegisterReceiver(rpcProviderServer, rpcProviderEndpoint)
require.NoError(t, err)
Expand Down Expand Up @@ -1149,3 +1149,76 @@ func TestSameProviderConflictReport(t *testing.T) {
require.True(t, twoProvidersConflictSent)
})
}

func TestConsumerProviderStatic(t *testing.T) {
ctx := context.Background()
// can be any spec and api interface
specId := "LAV1"
apiInterface := spectypes.APIInterfaceTendermintRPC
epoch := uint64(100)
requiredResponses := 1
lavaChainID := "lava"

numProviders := 1

consumerListenAddress := addressGen.GetAddress()
pairingList := map[uint64]*lavasession.ConsumerSessionsWithProvider{}
type providerData struct {
account sigs.Account
endpoint *lavasession.RPCProviderEndpoint
server *rpcprovider.RPCProviderServer
replySetter *ReplySetter
mockChainFetcher *MockChainFetcher
}
providers := []providerData{}

for i := 0; i < numProviders; i++ {
account := sigs.GenerateDeterministicFloatingKey(randomizer)
providerDataI := providerData{account: account}
providers = append(providers, providerDataI)
}
consumerAccount := sigs.GenerateDeterministicFloatingKey(randomizer)
for i := 0; i < numProviders; i++ {
ctx := context.Background()
providerDataI := providers[i]
listenAddress := addressGen.GetAddress()
providers[i].server, providers[i].endpoint, providers[i].replySetter, providers[i].mockChainFetcher, _ = createRpcProvider(t, ctx, consumerAccount.Addr.String(), specId, apiInterface, listenAddress, providerDataI.account, lavaChainID, []string(nil), fmt.Sprintf("provider%d", i))
}
// provider is static
for i := 0; i < numProviders; i++ {
pairingList[uint64(i)] = &lavasession.ConsumerSessionsWithProvider{
PublicLavaAddress: "BANANA" + strconv.Itoa(i),
Endpoints: []*lavasession.Endpoint{
{
NetworkAddress: providers[i].endpoint.NetworkAddress.Address,
Enabled: true,
Geolocation: 1,
},
},
Sessions: map[int64]*lavasession.SingleConsumerSession{},
MaxComputeUnits: 10000,
UsedComputeUnits: 0,
PairingEpoch: epoch,
StaticProvider: true,
}
}
rpcconsumerServer, _ := createRpcConsumer(t, ctx, specId, apiInterface, consumerAccount, consumerListenAddress, epoch, pairingList, requiredResponses, lavaChainID)
require.NotNil(t, rpcconsumerServer)
client := http.Client{}
// consumer sends the relay to a provider with an address BANANA+%d so the provider needs to skip validations for this to work
resp, err := client.Get("http://" + consumerListenAddress + "/status")
require.NoError(t, err)
// we expect provider to fail the request on a verification
require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
for i := 0; i < numProviders; i++ {
providers[i].server.StaticProvider = true
}
resp, err = client.Get("http://" + consumerListenAddress + "/status")
require.NoError(t, err)
// we expect provider to fail the request on a verification
require.Equal(t, http.StatusOK, resp.StatusCode)
bodyBytes, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, providers[0].replySetter.replyDataBuf, bodyBytes)
resp.Body.Close()
}
2 changes: 2 additions & 0 deletions protocol/lavasession/consumer_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ type ConsumerSessionsWithProvider struct {
// blocked provider recovery status if 0 currently not used, if 1 a session has tried resume communication with this provider
// if the provider is not blocked at all this field is irrelevant
blockedAndUsedWithChanceForRecoveryStatus uint32
StaticProvider bool
}

func NewConsumerSessionWithProvider(publicLavaAddress string, pairingEndpoints []*Endpoint, maxCu uint64, epoch uint64, stakeSize sdk.Coin) *ConsumerSessionsWithProvider {
Expand Down Expand Up @@ -435,6 +436,7 @@ func (cswp *ConsumerSessionsWithProvider) GetConsumerSessionInstanceFromEndpoint
SessionId: randomSessionId,
Parent: cswp,
EndpointConnection: endpointConnection,
StaticProvider: cswp.StaticProvider,
}

consumerSession.TryUseSession() // we must lock the session so other requests wont get it.
Expand Down
1 change: 1 addition & 0 deletions protocol/lavasession/single_consumer_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type SingleConsumerSession struct {
errorsCount uint64
relayProcessor UsedProvidersInf
providerUniqueId string
StaticProvider bool
}

// returns the expected latency to a threshold.
Expand Down
2 changes: 1 addition & 1 deletion protocol/rpcconsumer/consumer_state_tracker_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 19 additions & 3 deletions protocol/rpcconsumer/rpcconsumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/lavanet/lava/v2/protocol/metrics"
"github.com/lavanet/lava/v2/protocol/performance"
"github.com/lavanet/lava/v2/protocol/provideroptimizer"
"github.com/lavanet/lava/v2/protocol/rpcprovider"
"github.com/lavanet/lava/v2/protocol/statetracker"
"github.com/lavanet/lava/v2/protocol/statetracker/updaters"
"github.com/lavanet/lava/v2/protocol/upgrade"
Expand Down Expand Up @@ -89,7 +90,7 @@ func (s *strategyValue) Type() string {

type ConsumerStateTrackerInf interface {
RegisterForVersionUpdates(ctx context.Context, version *protocoltypes.Version, versionValidator updaters.VersionValidationInf)
RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager)
RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager, staticProvidersList []*lavasession.RPCProviderEndpoint)
RegisterForSpecUpdates(ctx context.Context, specUpdatable updaters.SpecUpdatable, endpoint lavasession.RPCEndpoint) error
RegisterFinalizationConsensusForUpdates(context.Context, *finalizationconsensus.FinalizationConsensus)
RegisterForDowntimeParamsUpdates(ctx context.Context, downtimeParamsUpdatable updaters.DowntimeParamsUpdatable) error
Expand Down Expand Up @@ -121,6 +122,7 @@ type rpcConsumerStartOptions struct {
cmdFlags common.ConsumerCmdFlags
stateShare bool
refererData *chainlib.RefererData
staticProvidersList []*lavasession.RPCProviderEndpoint // define static providers as backup to lava providers
}

// spawns a new RPCConsumer server with all it's processes and internals ready for communications
Expand Down Expand Up @@ -287,7 +289,7 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt
activeSubscriptionProvidersStorage := lavasession.NewActiveSubscriptionProvidersStorage()
consumerSessionManager := lavasession.NewConsumerSessionManager(rpcEndpoint, optimizer, consumerMetricsManager, consumerReportsManager, consumerAddr.String(), activeSubscriptionProvidersStorage)
// Register For Updates
rpcc.consumerStateTracker.RegisterConsumerSessionManagerForPairingUpdates(ctx, consumerSessionManager)
rpcc.consumerStateTracker.RegisterConsumerSessionManagerForPairingUpdates(ctx, consumerSessionManager, options.staticProvidersList)

var relaysMonitor *metrics.RelaysMonitor
if options.cmdFlags.RelaysHealthEnableFlag {
Expand Down Expand Up @@ -505,6 +507,20 @@ rpcconsumer consumer_examples/full_consumer_example.yml --cache-be "127.0.0.1:77
if gasPricesStr == "" {
gasPricesStr = statetracker.DefaultGasPrice
}

// check if StaticProvidersConfigName exists in viper, if it does parse it with ParseStaticProvider function
var staticProviderEndpoints []*lavasession.RPCProviderEndpoint
if viper.IsSet(common.StaticProvidersConfigName) {
staticProviderEndpoints, err = rpcprovider.ParseEndpointsCustomName(viper.GetViper(), common.StaticProvidersConfigName, geolocation)
if err != nil {
return utils.LavaFormatError("invalid static providers definition", err)
}
for _, endpoint := range staticProviderEndpoints {
utils.LavaFormatInfo("Static Provider Endpoint:", utils.Attribute{Key: "Urls", Value: endpoint.NodeUrls}, utils.Attribute{Key: "Chain ID", Value: endpoint.ChainID}, utils.Attribute{Key: "API Interface", Value: endpoint.ApiInterface})
}
}

// set up the txFactory with gas adjustments and gas
txFactory = txFactory.WithGasAdjustment(viper.GetFloat64(flags.FlagGasAdjustment))
txFactory = txFactory.WithGasPrices(gasPricesStr)
utils.LavaFormatInfo("Setting gas for tx Factory", utils.LogAttr("gas-prices", gasPricesStr), utils.LogAttr("gas-adjustment", txFactory.GasAdjustment()))
Expand Down Expand Up @@ -560,7 +576,7 @@ rpcconsumer consumer_examples/full_consumer_example.yml --cache-be "127.0.0.1:77
}

rpcConsumerSharedState := viper.GetBool(common.SharedStateFlag)
err = rpcConsumer.Start(ctx, &rpcConsumerStartOptions{txFactory, clientCtx, rpcEndpoints, requiredResponses, cache, strategyFlag.Strategy, maxConcurrentProviders, analyticsServerAddressess, consumerPropagatedFlags, rpcConsumerSharedState, refererData})
err = rpcConsumer.Start(ctx, &rpcConsumerStartOptions{txFactory, clientCtx, rpcEndpoints, requiredResponses, cache, strategyFlag.Strategy, maxConcurrentProviders, analyticsServerAddressess, consumerPropagatedFlags, rpcConsumerSharedState, refererData, staticProviderEndpoints})
return err
},
}
Expand Down
14 changes: 10 additions & 4 deletions protocol/rpcconsumer/rpcconsumer_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1043,16 +1043,22 @@ func (rpccs *RPCConsumerServer) relayInner(ctx context.Context, singleConsumerSe

filteredHeaders, _, ignoredHeaders := rpccs.chainParser.HandleHeaders(reply.Metadata, chainMessage.GetApiCollection(), spectypes.Header_pass_reply)
reply.Metadata = filteredHeaders
err = lavaprotocol.VerifyRelayReply(ctx, reply, relayRequest, providerPublicAddress)
if err != nil {
return 0, err, false

// check the signature on the reply
if !singleConsumerSession.StaticProvider {
err = lavaprotocol.VerifyRelayReply(ctx, reply, relayRequest, providerPublicAddress)
if err != nil {
return 0, err, false
}
}

reply.Metadata = append(reply.Metadata, ignoredHeaders...)

// TODO: response data sanity, check its under an expected format add that format to spec
enabled, _ := rpccs.chainParser.DataReliabilityParams()
if enabled {
if enabled && !singleConsumerSession.StaticProvider {
// TODO: allow static providers to detect hash mismatches,
// triggering conflict with them is impossible so we skip this for now, but this can be used to block malicious providers
finalizedBlocks, err := finalizationverification.VerifyFinalizationData(reply, relayRequest, providerPublicAddress, rpccs.ConsumerAddress, existingSessionLatestBlock, int64(blockDistanceForFinalizedData), int64(blocksInFinalizationProof))
if err != nil {
if sdkerrors.IsOf(err, protocolerrors.ProviderFinalizationDataAccountabilityError) {
Expand Down
21 changes: 21 additions & 0 deletions protocol/rpcprovider/rewardserver/reward_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,21 @@ type PaymentConfiguration struct {
shouldAddExpectedPayment bool
}

// used to disable provider rewards claiming
type DisabledRewardServer struct{}

func (rws *DisabledRewardServer) SendNewProof(ctx context.Context, proof *pairingtypes.RelaySession, epoch uint64, consumerAddr string, apiInterface string) (existingCU uint64, updatedWithProof bool) {
return 0, true
}

func (rws *DisabledRewardServer) SubscribeStarted(consumer string, epoch uint64, subscribeID string) {
// TODO: hold off reward claims for subscription while this is still active
}

func (rws *DisabledRewardServer) SubscribeEnded(consumer string, epoch uint64, subscribeID string) {
// TODO: can collect now
}

type RewardServer struct {
rewardsTxSender RewardsTxSender
lock sync.RWMutex
Expand Down Expand Up @@ -464,6 +479,9 @@ func (rws *RewardServer) updateCUPaid(cu uint64) {
}

func (rws *RewardServer) AddDataBase(specId string, providerPublicAddress string, shardID uint) {
if rws == nil {
return
}
// the db itself doesn't need locks. as it self manages locks inside.
// but opening a db can race. (NewLocalDB) so we lock this method.
// Also, we construct the in-memory rewards from the DB, so that needs a lock as well
Expand All @@ -477,6 +495,9 @@ func (rws *RewardServer) AddDataBase(specId string, providerPublicAddress string
}

func (rws *RewardServer) CloseAllDataBases() error {
if rws == nil {
return nil
}
return rws.rewardDB.Close()
}

Expand Down
Loading

0 comments on commit bc2e085

Please sign in to comment.