Skip to content

Commit

Permalink
fix: race in pg and type filtering in parsers (keploy#1947)
Browse files Browse the repository at this point in the history
* fix: race in pg and type filtering in parsers

Signed-off-by: shivamsouravjha <[email protected]>

* chore: rename functions

Signed-off-by: shivamsouravjha <[email protected]>

* chore: remove print

Signed-off-by: shivamsouravjha <[email protected]>

---------

Signed-off-by: shivamsouravjha <[email protected]>
Former-commit-id: 1cbb043
  • Loading branch information
shivamsouravjha authored Jun 7, 2024
1 parent 96ba3a6 commit 7e58e2d
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 32 deletions.
3 changes: 3 additions & 0 deletions pkg/core/proxy/integrations/generic/match.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ func fuzzyMatch(ctx context.Context, reqBuff [][]byte, mockDb integrations.MockM
var unfilteredMocks []*models.Mock

for _, mock := range mocks {
if mock.Kind != "Generic" {
continue
}
if mock.TestModeInfo.IsFiltered {
filteredMocks = append(filteredMocks, mock)
} else {
Expand Down
12 changes: 9 additions & 3 deletions pkg/core/proxy/integrations/http/match.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,23 @@ func match(ctx context.Context, logger *zap.Logger, input *req, mockDb integrati
}

mocks, err := mockDb.GetUnFilteredMocks()

var unfilteredMocks []*models.Mock
for _, mock := range mocks {
if mock.Kind != "Http" {
continue
}
unfilteredMocks = append(unfilteredMocks, mock)
}
if err != nil {
utils.LogError(logger, err, "failed to get unfilteredMocks mocks")
return false, nil, errors.New("error while matching the request with the mocks")
}

logger.Debug(fmt.Sprintf("Length of unfilteredMocks:%v", len(mocks)))
logger.Debug(fmt.Sprintf("Length of unfilteredMocks:%v", len(unfilteredMocks)))

var schemaMatched []*models.Mock

for _, mock := range mocks {
for _, mock := range unfilteredMocks {
if ctx.Err() != nil {
return false, nil, ctx.Err()
}
Expand Down
9 changes: 8 additions & 1 deletion pkg/core/proxy/integrations/mongo/match.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,17 @@ func match(ctx context.Context, logger *zap.Logger, mongoRequests []models.Mongo
case <-ctx.Done():
return false, nil, ctx.Err()
default:
tcsMocks, err := mockDb.GetFilteredMocks()
mocks, err := mockDb.GetFilteredMocks()
if err != nil {
return false, nil, fmt.Errorf("error while getting tcs mock: %v", err)
}
var tcsMocks []*models.Mock
for _, mock := range mocks {
if mock.Kind != "Mongo" {
continue
}
tcsMocks = append(tcsMocks, mock)
}
maxMatchScore := 0.0
bestMatchIndex := -1
for tcsIndx, tcsMock := range tcsMocks {
Expand Down
19 changes: 16 additions & 3 deletions pkg/core/proxy/integrations/mysql/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,30 @@ func decodeMySQL(ctx context.Context, logger *zap.Logger, clientConn net.Conn, d
prevRequest := ""
var requestBuffers [][]byte

configMocks, err := mockDb.GetUnFilteredMocks()
mocks, err := mockDb.GetUnFilteredMocks()
if err != nil {
utils.LogError(logger, err, "Failed to get unfiltered mocks")
return err
}

tcsMocks, err := mockDb.GetFilteredMocks()
var configMocks []*models.Mock
for _, mock := range mocks {
if mock.Kind != "SQL" {
continue
}
configMocks = append(configMocks, mock)
}
mocks, err = mockDb.GetFilteredMocks()
if err != nil {
utils.LogError(logger, err, "Failed to get filtered mocks")
return err
}
var tcsMocks []*models.Mock
for _, mock := range mocks {
if mock.Kind != "SQL" {
continue
}
tcsMocks = append(tcsMocks, mock)
}

errCh := make(chan error, 1)

Expand Down
34 changes: 21 additions & 13 deletions pkg/core/proxy/integrations/postgres/v1/match.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,15 @@ func matchingReadablePG(ctx context.Context, logger *zap.Logger, mutex *sync.Mut
return false, nil, ctx.Err()
default:

tcsMocks, err := mockDb.GetUnFilteredMocks()
mocks, err := mockDb.GetUnFilteredMocks()
var tcsMocks []*models.Mock

for _, mock := range mocks {
if mock.Kind != "Postgres" {
continue
}
tcsMocks = append(tcsMocks, mock)
}
if err != nil {
return false, nil, fmt.Errorf("error while getting tcs mocks %v", err)
}
Expand Down Expand Up @@ -116,10 +124,10 @@ func matchingReadablePG(ctx context.Context, logger *zap.Logger, mutex *sync.Mut
mutex.Unlock()

initMock := *mock
if len(mock.Spec.PostgresRequests) == len(requestBuffers) {
if len(initMock.Spec.PostgresRequests) == len(requestBuffers) {
for requestIndex, reqBuff := range requestBuffers {
bufStr := base64.StdEncoding.EncodeToString(reqBuff)
encodedMock, err := postgresDecoderBackend(mock.Spec.PostgresRequests[requestIndex])
encodedMock, err := postgresDecoderBackend(initMock.Spec.PostgresRequests[requestIndex])
if err != nil {
logger.Debug("Error while decoding postgres request", zap.Error(err))
}
Expand All @@ -130,16 +138,16 @@ func matchingReadablePG(ctx context.Context, logger *zap.Logger, mutex *sync.Mut
Payload: "Tg==",
}
return true, []models.Frontend{ssl}, nil
case mock.Spec.PostgresRequests[requestIndex].Identfier == "StartupRequest" && isStartupPacket(reqBuff) && mock.Spec.PostgresRequests[requestIndex].Payload != "AAAACATSFi8=" && mock.Spec.PostgresResponses[requestIndex].AuthType == 10:
logger.Debug("CHANGING TO MD5 for Response", zap.String("mock", mock.Name), zap.String("Req", bufStr))
case initMock.Spec.PostgresRequests[requestIndex].Identfier == "StartupRequest" && isStartupPacket(reqBuff) && initMock.Spec.PostgresRequests[requestIndex].Payload != "AAAACATSFi8=" && initMock.Spec.PostgresResponses[requestIndex].AuthType == 10:
logger.Debug("CHANGING TO MD5 for Response", zap.String("mock", initMock.Name), zap.String("Req", bufStr))
initMock.Spec.PostgresResponses[requestIndex].AuthType = 5
err := mockDb.FlagMockAsUsed(&initMock)
if err != nil {
logger.Error("failed to flag mock as used", zap.Error(err))
}
return true, initMock.Spec.PostgresResponses, nil
case len(encodedMock) > 0 && encodedMock[0] == 'p' && mock.Spec.PostgresRequests[requestIndex].PacketTypes[0] == "p" && reqBuff[0] == 'p':
logger.Debug("CHANGING TO MD5 for Request and Response", zap.String("mock", mock.Name), zap.String("Req", bufStr))
case len(encodedMock) > 0 && encodedMock[0] == 'p' && initMock.Spec.PostgresRequests[requestIndex].PacketTypes[0] == "p" && reqBuff[0] == 'p':
logger.Debug("CHANGING TO MD5 for Request and Response", zap.String("mock", initMock.Name), zap.String("Req", bufStr))

initMock.Spec.PostgresRequests[requestIndex].PasswordMessage.Password = "md5fe4f2f657f01fa1dd9d111d5391e7c07"

Expand Down Expand Up @@ -205,6 +213,7 @@ func matchingReadablePG(ctx context.Context, logger *zap.Logger, mutex *sync.Mut

}
}

// maintain test prepare statement map for each connection id
getTestPS(requestBuffers, logger, ConnectionID)
}
Expand All @@ -227,12 +236,11 @@ func matchingReadablePG(ctx context.Context, logger *zap.Logger, mutex *sync.Mut
logger.Debug("Matched In Sorted PG Matching Stream", zap.String("mock", matchedMock.Name))
}

// idx = findBinaryStreamMatch(logger, sortedTcsMocks, requestBuffers, sorted)
// if idx != -1 && !matched {
// matched = true
// matchedMock = tcsMocks[idx]
// fmt.Println("Matched In Binary Matching for Sorted", matchedMock.Name)
// }
idx = findBinaryStreamMatch(logger, sortedTcsMocks, requestBuffers, sorted)
if idx != -1 && !matched {
matched = true
matchedMock = tcsMocks[idx]
}
}

if !matched {
Expand Down
26 changes: 14 additions & 12 deletions pkg/core/proxy/mockmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,25 +48,27 @@ func (m *MockManager) SetUnFilteredMocks(mocks []*models.Mock) {
func (m *MockManager) GetFilteredMocks() ([]*models.Mock, error) {
var tcsMocks []*models.Mock
mocks := m.filtered.getAll()
for _, m := range mocks {
if mock, ok := m.(*models.Mock); ok {
tcsMocks = append(tcsMocks, mock)
} else {
return nil, fmt.Errorf("expected mock instance, got %v", m)
}
//sending copy of mocks instead of actual mocks
mockCopy, err := localMock(mocks)
if err != nil {
return nil, fmt.Errorf("expected mock instance, got %v", m)
}
for _, m := range mockCopy {
tcsMocks = append(tcsMocks, &m)
}
return tcsMocks, nil
}

func (m *MockManager) GetUnFilteredMocks() ([]*models.Mock, error) {
var configMocks []*models.Mock
mocks := m.unfiltered.getAll()
for _, m := range mocks {
if mock, ok := m.(*models.Mock); ok {
configMocks = append(configMocks, mock)
} else {
return nil, fmt.Errorf("expected mock instance, got %v", m)
}
//sending copy of mocks instead of actual mocks
mockCopy, err := localMock(mocks)
if err != nil {
return nil, fmt.Errorf("expected mock instance, got %v", m)
}
for _, m := range mockCopy {
configMocks = append(configMocks, &m)
}
return configMocks, nil
}
Expand Down
12 changes: 12 additions & 0 deletions pkg/core/proxy/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,15 @@ func (p *Proxy) globalPassThrough(ctx context.Context, client, dest net.Conn) er
}
}
}

func localMock(copyMock []interface{}) ([]models.Mock, error) {
var copiedMocks []models.Mock
for _, m := range copyMock {
if mock, ok := m.(*models.Mock); ok {
copiedMocks = append(copiedMocks, *mock)
} else {
return nil, fmt.Errorf("expected mock instance, got %v", m)
}
}
return copiedMocks, nil
}

0 comments on commit 7e58e2d

Please sign in to comment.