Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
* conn/message: add mutex to routes/codes maps

When running in high concurrency scenarios the routes/codes maps may
be updated by concurrent goroutines at the same time, resulting in a
race condition while calling `SetDictionary`.

This patch adds a mutex to control access to these maps.

Signed-off-by: Rodrigo Chacon <[email protected]>

* conn/message: protect routes map exposure on GetDictionary

Since Go maps work as pointers to the internal data structures, when
calling GetDictionary we're actually exposing the internal routes map
for its callers.

This commit copies the routes map into a new map instance to prevent
its exposure.

Signed-off-by: Rodrigo Chacon <[email protected]>

Co-authored-by: Rodrigo Chacon <[email protected]>
  • Loading branch information
victor-carvalho and rochacon authored Jun 17, 2021
1 parent 5b797cb commit dbe567b
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 4 deletions.
16 changes: 13 additions & 3 deletions conn/message/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"errors"
"fmt"
"strings"
"sync"
)

// Type represents the type of message, which could be Request/Notify/Response/Push
Expand Down Expand Up @@ -54,8 +55,9 @@ var types = map[Type]string{
}

var (
routes = make(map[string]uint16) // route map to code
codes = make(map[uint16]string) // code map to route
routesCodesMutex = sync.RWMutex{}
routes = make(map[string]uint16) // route map to code
codes = make(map[uint16]string) // code map to route
)

// Errors that could be occurred in message codec
Expand Down Expand Up @@ -110,6 +112,8 @@ func SetDictionary(dict map[string]uint16) error {
if dict == nil {
return nil
}
routesCodesMutex.Lock()
defer routesCodesMutex.Unlock()

for route, code := range dict {
r := strings.TrimSpace(route)
Expand All @@ -133,7 +137,13 @@ func SetDictionary(dict map[string]uint16) error {

// GetDictionary gets the routes map which is used to compress route.
func GetDictionary() map[string]uint16 {
return routes
routesCodesMutex.RLock()
defer routesCodesMutex.RUnlock()
dict := make(map[string]uint16)
for k, v := range routes {
dict[k] = v
}
return dict
}

func (t *Type) String() string {
Expand Down
4 changes: 4 additions & 0 deletions conn/message/message_encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ func (me *MessagesEncoder) Encode(message *Message) ([]byte, error) {
buf := make([]byte, 0)
flag := byte(message.Type) << 1

routesCodesMutex.RLock()
code, compressed := routes[message.Route]
routesCodesMutex.RUnlock()
if compressed {
flag |= msgRouteCompressMask
}
Expand Down Expand Up @@ -163,7 +165,9 @@ func Decode(data []byte) (*Message, error) {
if flag&msgRouteCompressMask == 1 {
m.compressed = true
code := binary.BigEndian.Uint16(data[offset:(offset + 2)])
routesCodesMutex.RLock()
route, ok := codes[code]
routesCodesMutex.RUnlock()
if !ok {
return nil, ErrRouteInfoNotFound
}
Expand Down
41 changes: 40 additions & 1 deletion conn/message/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package message
import (
"errors"
"flag"
"fmt"
"path/filepath"
"testing"

Expand All @@ -14,6 +15,8 @@ var update = flag.Bool("update", false, "update .golden files")

func resetDicts(t *testing.T) {
t.Helper()
routesCodesMutex.Lock()
defer routesCodesMutex.Unlock()
routes = make(map[string]uint16)
codes = make(map[uint16]string)
}
Expand Down Expand Up @@ -156,7 +159,7 @@ var dictTables = map[string]struct {
map[uint16]string{1: "a"}, errors.New("duplicated route(route: b, code: 1)")},
}

func TestSetDictionaty(t *testing.T) {
func TestSetDictionary(t *testing.T) {
for name, table := range dictTables {
t.Run(name, func(t *testing.T) {
for _, dict := range table.dicts {
Expand All @@ -170,3 +173,39 @@ func TestSetDictionaty(t *testing.T) {
})
}
}

func TestSetDictionaryRace(t *testing.T) {
defer resetDicts(t)

done := make(chan bool, 2)

setDictRace := func(dict map[string]uint16) {
assert.Nil(t, SetDictionary(dict))
done <- true
}

go setDictRace(map[string]uint16{"a": 1})
go setDictRace(map[string]uint16{"b": 2})

// wait for both setDictRace to finish
<-done
<-done

expected_codes := map[uint16]string{1: "a", 2: "b"}
assert.EqualValues(t, expected_codes, codes)

expected_routes := map[string]uint16{"a": 1, "b": 2}
assert.EqualValues(t, expected_routes, routes)
}

func TestGetDictionary(t *testing.T) {
defer resetDicts(t)
expected := map[string]uint16{"a": 1, "b": 2}
assert.Nil(t, SetDictionary(expected))

dict := GetDictionary()
assert.Equal(t, expected, dict)

// make sure we're copying the routes maps
assert.NotEqual(t, fmt.Sprintf("%p", routes), fmt.Sprintf("%p", dict))
}

0 comments on commit dbe567b

Please sign in to comment.