Skip to content

Commit

Permalink
PRT-637: Allow gRPC & REST Protocol Header Proxying-FIXED (lavanet#508)
Browse files Browse the repository at this point in the history
* prt-637 changes added

* redundand code removed
  • Loading branch information
candostyavuz authored May 25, 2023
1 parent 170ea39 commit e461c0d
Show file tree
Hide file tree
Showing 19 changed files with 461 additions and 109 deletions.
6 changes: 6 additions & 0 deletions proto/pairing/relay.proto
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ message RelayPrivateData {
int64 request_block = 4;
string api_interface = 5;
bytes salt = 6;
repeated Metadata metadata = 7 [(gogoproto.nullable) = false];
}

message Metadata {
string name = 1;
string value = 2;
}

message RelayRequest {
Expand Down
3 changes: 2 additions & 1 deletion protocol/chainlib/chainlib.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func NewChainListener(ctx context.Context, listenEndpoint *lavasession.RPCEndpoi
}

type ChainParser interface {
ParseMsg(url string, data []byte, connectionType string) (ChainMessage, error)
ParseMsg(url string, data []byte, connectionType string, metadata []pairingtypes.Metadata) (ChainMessage, error)
SetSpec(spec spectypes.Spec)
DataReliabilityParams() (enabled bool, dataReliabilityThreshold uint32)
ChainBlockStats() (allowedBlockLagForQosSync int64, averageBlockTime time.Duration, blockDistanceForFinalizedData uint32, blocksInFinalizationProof uint32)
Expand All @@ -69,6 +69,7 @@ type RelaySender interface {
connectionType string,
dappID string,
analytics *metrics.RelayMetrics,
metadataValues []pairingtypes.Metadata,
) (*pairingtypes.RelayReply, *pairingtypes.Relayer_RelaySubscribeClient, error)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/jhump/protoreflect/grpcreflect"
"github.com/lavanet/lava/protocol/parser"
"github.com/lavanet/lava/utils"
pairingtypes "github.com/lavanet/lava/x/pairing/types"
"google.golang.org/grpc/codes"
)

Expand All @@ -21,6 +22,7 @@ type GrpcMessage struct {
Path string
methodDesc *desc.MethodDescriptor
formatter grpcurl.Formatter
Header []pairingtypes.Metadata
}

// GetParams will be deprecated after we remove old client
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import (
"strings"

"github.com/lavanet/lava/protocol/parser"
pairingtypes "github.com/lavanet/lava/x/pairing/types"
)

type RestMessage struct {
Msg []byte
Path string
SpecPath string
Header []pairingtypes.Metadata
}

// GetParams will be deprecated after we remove old client
Expand Down
32 changes: 27 additions & 5 deletions protocol/chainlib/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func NewGrpcChainParser() (chainParser *GrpcChainParser, err error) {

func (apip *GrpcChainParser) CraftMessage(serviceApi spectypes.ServiceApi, craftData *CraftData) (ChainMessageForSend, error) {
if craftData != nil {
return apip.ParseMsg(craftData.Path, craftData.Data, craftData.ConnectionType)
return apip.ParseMsg(craftData.Path, craftData.Data, craftData.ConnectionType, nil)
}

grpcMessage := &rpcInterfaceMessages.GrpcMessage{
Expand All @@ -58,7 +58,7 @@ func (apip *GrpcChainParser) CraftMessage(serviceApi spectypes.ServiceApi, craft
}

// ParseMsg parses message data into chain message object
func (apip *GrpcChainParser) ParseMsg(url string, data []byte, connectionType string) (ChainMessage, error) {
func (apip *GrpcChainParser) ParseMsg(url string, data []byte, connectionType string, metadata []pairingtypes.Metadata) (ChainMessage, error) {
// Guard that the GrpcChainParser instance exists
if apip == nil {
return nil, errors.New("GrpcChainParser not defined")
Expand All @@ -77,8 +77,9 @@ func (apip *GrpcChainParser) ParseMsg(url string, data []byte, connectionType st

// Construct grpcMessage
grpcMessage := rpcInterfaceMessages.GrpcMessage{
Msg: data,
Path: url,
Msg: data,
Path: url,
Header: metadata,
}

// // Fetch requested block, it is used for data reliability
Expand Down Expand Up @@ -216,10 +217,11 @@ func (apil *GrpcChainListener) Serve(ctx context.Context) {
ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier())
msgSeed := apil.logger.GetMessageSeed()
metadataValues, _ := metadata.FromIncomingContext(ctx)
grpcHeaders := convertToMetadataGrpc(metadataValues)
utils.LavaFormatInfo("GRPC Got Relay ", utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "method", Value: method})
var relayReply *pairingtypes.RelayReply
metricsData := metrics.NewRelayAnalytics("NoDappID", apil.endpoint.ChainID, apiInterface)
relayReply, _, err := apil.relaySender.SendRelay(ctx, method, string(reqBody), "", "NoDappID", metricsData)
relayReply, _, err := apil.relaySender.SendRelay(ctx, method, string(reqBody), "", "NoDappID", metricsData, grpcHeaders)
go apil.logger.AddMetricForGrpc(metricsData, err, &metadataValues)

if err != nil {
Expand Down Expand Up @@ -283,6 +285,15 @@ func (cp *GrpcChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{},
if !ok {
return nil, "", nil, utils.LavaFormatError("invalid message type in grpc failed to cast RPCInput from chainMessage", nil, utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "rpcMessage", Value: rpcInputMessage})
}
if len(nodeMessage.Header) > 0 {
metadataMap := make(map[string]string)
for _, metaData := range nodeMessage.Header {
metadataMap[metaData.Name] = metaData.Value
}
md := metadata.New(metadataMap)
ctx = metadata.NewOutgoingContext(ctx, md)
}

relayTimeout := common.LocalNodeTimePerCu(chainMessage.GetServiceApi().ComputeUnits)
// check if this API is hanging (waiting for block confirmation)
if chainMessage.GetInterface().Category.HangingApi {
Expand Down Expand Up @@ -374,6 +385,17 @@ func (cp *GrpcChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{},
return reply, "", nil, nil
}

func convertToMetadataGrpc(md map[string][]string) []pairingtypes.Metadata {
metadata := make([]pairingtypes.Metadata, len(md))
indexer := 0
for k, v := range md {
metadata[indexer] = pairingtypes.Metadata{Name: k, Value: v[0]}
indexer += 1
}
fmt.Println("metadata: ", metadata)
return metadata
}

func marshalJSON(msg proto.Message) ([]byte, error) {
if dyn, ok := msg.(*dynamic.Message); ok {
return dyn.MarshalJSON()
Expand Down
4 changes: 2 additions & 2 deletions protocol/chainlib/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func TestGRPChainParser_NilGuard(t *testing.T) {
apip.DataReliabilityParams()
apip.ChainBlockStats()
apip.getSupportedApi("")
apip.ParseMsg("", []byte{}, "")
apip.ParseMsg("", []byte{}, "", nil)
}

func TestGRPCGetSupportedApi(t *testing.T) {
Expand Down Expand Up @@ -111,7 +111,7 @@ func TestGRPCParseMessage(t *testing.T) {
},
}

msg, err := apip.ParseMsg("API1", []byte("test message"), spectypes.APIInterfaceGrpc)
msg, err := apip.ParseMsg("API1", []byte("test message"), spectypes.APIInterfaceGrpc, nil)

assert.Nil(t, err)
assert.Equal(t, msg.GetServiceApi().Name, apip.serverApis["API1"].Name)
Expand Down
8 changes: 4 additions & 4 deletions protocol/chainlib/jsonRPC.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewJrpcChainParser() (chainParser *JsonRPCChainParser, err error) {

func (apip *JsonRPCChainParser) CraftMessage(serviceApi spectypes.ServiceApi, craftData *CraftData) (ChainMessageForSend, error) {
if craftData != nil {
return apip.ParseMsg("", craftData.Data, craftData.ConnectionType)
return apip.ParseMsg("", craftData.Data, craftData.ConnectionType, nil)
}

msg := rpcInterfaceMessages.JsonrpcMessage{
Expand All @@ -53,7 +53,7 @@ func (apip *JsonRPCChainParser) CraftMessage(serviceApi spectypes.ServiceApi, cr
}

// this func parses message data into chain message object
func (apip *JsonRPCChainParser) ParseMsg(url string, data []byte, connectionType string) (ChainMessage, error) {
func (apip *JsonRPCChainParser) ParseMsg(url string, data []byte, connectionType string, metadata []pairingtypes.Metadata) (ChainMessage, error) {
// Guard that the JsonRPCChainParser instance exists
if apip == nil {
return nil, errors.New("JsonRPCChainParser not defined")
Expand Down Expand Up @@ -246,7 +246,7 @@ func (apil *JsonRPCChainListener) Serve(ctx context.Context) {
defer cancel() // incase there's a problem make sure to cancel the connection
utils.LavaFormatDebug("ws in <<<", utils.Attribute{Key: "seed", Value: msgSeed}, utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "msg", Value: msg}, utils.Attribute{Key: "dappID", Value: dappID})
metricsData := metrics.NewRelayAnalytics(dappID, chainID, apiInterface)
reply, replyServer, err := apil.relaySender.SendRelay(ctx, "", string(msg), http.MethodPost, dappID, metricsData)
reply, replyServer, err := apil.relaySender.SendRelay(ctx, "", string(msg), http.MethodPost, dappID, metricsData, nil)
go apil.logger.AddMetricForWebSocket(metricsData, err, websockConn)

if err != nil {
Expand Down Expand Up @@ -309,7 +309,7 @@ func (apil *JsonRPCChainListener) Serve(ctx context.Context) {
if test_mode {
apil.logger.LogTestMode(fiberCtx)
}
reply, _, err := apil.relaySender.SendRelay(ctx, "", string(fiberCtx.Body()), http.MethodPost, dappID, metricsData)
reply, _, err := apil.relaySender.SendRelay(ctx, "", string(fiberCtx.Body()), http.MethodPost, dappID, metricsData, nil)
go apil.logger.AddMetricForHttp(metricsData, err, fiberCtx.GetReqHeaders())
if err != nil {
// Get unique GUID response
Expand Down
4 changes: 2 additions & 2 deletions protocol/chainlib/jsonRPC_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestJSONChainParser_NilGuard(t *testing.T) {
apip.DataReliabilityParams()
apip.ChainBlockStats()
apip.getSupportedApi("")
apip.ParseMsg("", []byte{}, "")
apip.ParseMsg("", []byte{}, "", nil)
}

func TestJSONGetSupportedApi(t *testing.T) {
Expand Down Expand Up @@ -116,7 +116,7 @@ func TestJSONParseMessage(t *testing.T) {

marshalledData, _ := json.Marshal(data)

msg, err := apip.ParseMsg("API1", marshalledData, spectypes.APIInterfaceJsonRPC)
msg, err := apip.ParseMsg("API1", marshalledData, spectypes.APIInterfaceJsonRPC, nil)

assert.Nil(t, err)
assert.Equal(t, msg.GetServiceApi().Name, apip.serverApis["API1"].Name)
Expand Down
38 changes: 29 additions & 9 deletions protocol/chainlib/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func NewRestChainParser() (chainParser *RestChainParser, err error) {
func (apip *RestChainParser) CraftMessage(serviceApi spectypes.ServiceApi, craftData *CraftData) (ChainMessageForSend, error) {
if craftData != nil {
// chain fetcher sends the replaced request inside data
return apip.ParseMsg(string(craftData.Data), nil, craftData.ConnectionType)
return apip.ParseMsg(string(craftData.Data), nil, craftData.ConnectionType, nil)
}

restMessage := rpcInterfaceMessages.RestMessage{
Expand All @@ -52,7 +52,7 @@ func (apip *RestChainParser) CraftMessage(serviceApi spectypes.ServiceApi, craft
}

// ParseMsg parses message data into chain message object
func (apip *RestChainParser) ParseMsg(url string, data []byte, connectionType string) (ChainMessage, error) {
func (apip *RestChainParser) ParseMsg(url string, data []byte, connectionType string, metadata []pairingtypes.Metadata) (ChainMessage, error) {
// Guard that the RestChainParser instance exists
if apip == nil {
return nil, errors.New("RestChainParser not defined")
Expand All @@ -74,14 +74,16 @@ func (apip *RestChainParser) ParseMsg(url string, data []byte, connectionType st

// Construct restMessage
restMessage := rpcInterfaceMessages.RestMessage{
Msg: data,
Path: url,
Msg: data,
Path: url,
Header: metadata,
}
if connectionType == http.MethodGet {
// support for optional params, our listener puts them inside Msg data
restMessage = rpcInterfaceMessages.RestMessage{
Msg: nil,
Path: url + string(data),
Msg: nil,
Path: url + string(data),
Header: metadata,
}
}
// add spec path to rest message so we can extract the requested block.
Expand Down Expand Up @@ -229,6 +231,8 @@ func (apil *RestChainListener) Serve(ctx context.Context) {

path := "/" + c.Params("*")

metadataValues := c.GetReqHeaders()
restHeaders := convertToMetadataRest(metadataValues)
ctx, cancel := context.WithCancel(context.Background())
ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier())
defer cancel() // incase there's a problem make sure to cancel the connection
Expand All @@ -239,7 +243,7 @@ func (apil *RestChainListener) Serve(ctx context.Context) {
analytics := metrics.NewRelayAnalytics(dappID, chainID, apiInterface)
utils.LavaFormatInfo("in <<<", utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "path", Value: path}, utils.Attribute{Key: "dappID", Value: dappID}, utils.Attribute{Key: "msgSeed", Value: msgSeed})
requestBody := string(c.Body())
reply, _, err := apil.relaySender.SendRelay(ctx, path, requestBody, http.MethodPost, dappID, analytics)
reply, _, err := apil.relaySender.SendRelay(ctx, path, requestBody, http.MethodPost, dappID, analytics, restHeaders)
go apil.logger.AddMetricForHttp(analytics, err, c.GetReqHeaders())

if err != nil {
Expand Down Expand Up @@ -276,12 +280,14 @@ func (apil *RestChainListener) Serve(ctx context.Context) {
dappID := extractDappIDFromFiberContext(c)
analytics := metrics.NewRelayAnalytics(dappID, chainID, apiInterface)

metadataValues := c.GetReqHeaders()
restHeaders := convertToMetadataRest(metadataValues)
ctx, cancel := context.WithCancel(context.Background())
ctx = utils.WithUniqueIdentifier(ctx, utils.GenerateUniqueIdentifier())
defer cancel() // incase there's a problem make sure to cancel the connection
utils.LavaFormatInfo("in <<<", utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "path", Value: path}, utils.Attribute{Key: "dappID", Value: dappID}, utils.Attribute{Key: "msgSeed", Value: msgSeed})

reply, _, err := apil.relaySender.SendRelay(ctx, path, query, http.MethodGet, dappID, analytics)
reply, _, err := apil.relaySender.SendRelay(ctx, path, query, http.MethodGet, dappID, analytics, restHeaders)
go apil.logger.AddMetricForHttp(analytics, err, c.GetReqHeaders())
if err != nil {
// Get unique GUID response
Expand Down Expand Up @@ -339,7 +345,6 @@ func (rcp *RestChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{},
if !ok {
return nil, "", nil, utils.LavaFormatError("invalid message type in rest, failed to cast RPCInput from chainMessage", nil, utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "rpcMessage", Value: rpcInputMessage})
}

var connectionTypeSlected string = http.MethodGet
// if ConnectionType is default value or empty we will choose http.MethodGet otherwise choosing the header type provided
if chainMessage.GetInterface().Type != "" {
Expand Down Expand Up @@ -367,6 +372,11 @@ func (rcp *RestChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{},
req.Header.Set("Content-Type", "application/json")
}

if len(nodeMessage.Header) > 0 {
for _, metadata := range nodeMessage.Header {
req.Header.Set(metadata.Name, metadata.Value)
}
}
rcp.NodeUrl.SetAuthHeaders(ctx, req.Header.Set)
rcp.NodeUrl.SetIpForwardingIfNecessary(ctx, req.Header.Set)

Expand All @@ -393,3 +403,13 @@ func (rcp *RestChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{},
}
return reply, "", nil, nil
}

func convertToMetadataRest(md map[string]string) []pairingtypes.Metadata {
metadata := make([]pairingtypes.Metadata, len(md))
indexer := 0
for k, v := range md {
metadata[indexer] = pairingtypes.Metadata{Name: k, Value: v}
indexer += 1
}
return metadata
}
4 changes: 2 additions & 2 deletions protocol/chainlib/rest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestRestChainParser_NilGuard(t *testing.T) {
apip.DataReliabilityParams()
apip.ChainBlockStats()
apip.getSupportedApi("")
apip.ParseMsg("", []byte{}, "")
apip.ParseMsg("", []byte{}, "", nil)
}

func TestRestGetSupportedApi(t *testing.T) {
Expand Down Expand Up @@ -105,7 +105,7 @@ func TestRestParseMessage(t *testing.T) {
},
}

msg, err := apip.ParseMsg("API1", []byte("test message"), spectypes.APIInterfaceRest)
msg, err := apip.ParseMsg("API1", []byte("test message"), spectypes.APIInterfaceRest, nil)

assert.Nil(t, err)
assert.Equal(t, msg.GetServiceApi().Name, apip.serverApis["API1"].Name)
Expand Down
10 changes: 5 additions & 5 deletions protocol/chainlib/tendermintRPC.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewTendermintRpcChainParser() (chainParser *TendermintChainParser, err erro

func (apip *TendermintChainParser) CraftMessage(serviceApi spectypes.ServiceApi, craftData *CraftData) (ChainMessageForSend, error) {
if craftData != nil {
return apip.ParseMsg("", craftData.Data, craftData.ConnectionType)
return apip.ParseMsg("", craftData.Data, craftData.ConnectionType, nil)
}

msg := rpcInterfaceMessages.JsonrpcMessage{
Expand All @@ -54,7 +54,7 @@ func (apip *TendermintChainParser) CraftMessage(serviceApi spectypes.ServiceApi,
}

// ParseMsg parses message data into chain message object
func (apip *TendermintChainParser) ParseMsg(url string, data []byte, connectionType string) (ChainMessage, error) {
func (apip *TendermintChainParser) ParseMsg(url string, data []byte, connectionType string, metadata []pairingtypes.Metadata) (ChainMessage, error) {
// Guard that the TendermintChainParser instance exists
if apip == nil {
return nil, errors.New("TendermintChainParser not defined")
Expand Down Expand Up @@ -288,7 +288,7 @@ func (apil *TendermintRpcChainListener) Serve(ctx context.Context) {
utils.LavaFormatInfo("ws in <<<", utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "seed", Value: msgSeed}, utils.Attribute{Key: "msg", Value: msg}, utils.Attribute{Key: "dappID", Value: dappID})

metricsData := metrics.NewRelayAnalytics(dappID, chainID, apiInterface)
reply, replyServer, err := apil.relaySender.SendRelay(ctx, "", string(msg), "", dappID, metricsData)
reply, replyServer, err := apil.relaySender.SendRelay(ctx, "", string(msg), "", dappID, metricsData, nil)
go apil.logger.AddMetricForWebSocket(metricsData, err, c)
if err != nil {
apil.logger.AnalyzeWebSocketErrorAndWriteMessage(c, mt, err, msgSeed, msg, "tendermint")
Expand Down Expand Up @@ -347,7 +347,7 @@ func (apil *TendermintRpcChainListener) Serve(ctx context.Context) {
defer cancel() // incase there's a problem make sure to cancel the connection

utils.LavaFormatInfo("in <<<", utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "seed", Value: msgSeed}, utils.Attribute{Key: "msg", Value: c.Body()}, utils.Attribute{Key: "dappID", Value: dappID})
reply, _, err := apil.relaySender.SendRelay(ctx, "", string(c.Body()), "", dappID, metricsData)
reply, _, err := apil.relaySender.SendRelay(ctx, "", string(c.Body()), "", dappID, metricsData, nil)
go apil.logger.AddMetricForHttp(metricsData, err, c.GetReqHeaders())

if err != nil {
Expand Down Expand Up @@ -386,7 +386,7 @@ func (apil *TendermintRpcChainListener) Serve(ctx context.Context) {
defer cancel() // incase there's a problem make sure to cancel the connection
utils.LavaFormatInfo("urirpc in <<<", utils.Attribute{Key: "GUID", Value: ctx}, utils.Attribute{Key: "seed", Value: msgSeed}, utils.Attribute{Key: "msg", Value: path}, utils.Attribute{Key: "dappID", Value: dappID})
metricsData := metrics.NewRelayAnalytics(dappID, chainID, apiInterface)
reply, _, err := apil.relaySender.SendRelay(ctx, path+query, "", "", dappID, metricsData)
reply, _, err := apil.relaySender.SendRelay(ctx, path+query, "", "", dappID, metricsData, nil)
go apil.logger.AddMetricForHttp(metricsData, err, c.GetReqHeaders())

if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions protocol/chainlib/tendermintRPC_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestTendermintChainParser_NilGuard(t *testing.T) {
apip.DataReliabilityParams()
apip.ChainBlockStats()
apip.getSupportedApi("")
apip.ParseMsg("", []byte{}, "")
apip.ParseMsg("", []byte{}, "", nil)
}

func TestTendermintGetSupportedApi(t *testing.T) {
Expand Down Expand Up @@ -119,7 +119,7 @@ func TestTendermintParseMessage(t *testing.T) {

marshalledData, _ := json.Marshal(data)

msg, err := apip.ParseMsg("API1", marshalledData, spectypes.APIInterfaceTendermintRPC)
msg, err := apip.ParseMsg("API1", marshalledData, spectypes.APIInterfaceTendermintRPC, nil)

assert.Nil(t, err)
assert.Equal(t, msg.GetServiceApi().Name, apip.serverApis["API1"].Name)
Expand Down
Loading

0 comments on commit e461c0d

Please sign in to comment.