From 7f0d41e57bcf770a6ab18595d011b3e003d123ae Mon Sep 17 00:00:00 2001
From: Andrii Ursulenko <a.ursulenko@gmail.com>
Date: Thu, 26 Sep 2019 12:41:03 +0300
Subject: [PATCH 1/2] add setter to address field of ID and add lock to make it
 concurrently updatable + fix race conditions

---
 skademlia/client.go      | 14 ++++++--------
 skademlia/client_test.go | 18 ++++++++++++++----
 skademlia/identity.go    | 30 ++++++++++++++++++++++--------
 skademlia/protocol.go    |  6 +++---
 skademlia/table.go       |  3 +--
 skademlia/table_test.go  | 12 ++++++++----
 6 files changed, 54 insertions(+), 29 deletions(-)

diff --git a/skademlia/client.go b/skademlia/client.go
index 8560f798..4a41a118 100644
--- a/skademlia/client.go
+++ b/skademlia/client.go
@@ -136,7 +136,7 @@ func (c *Client) ClosestPeers(opts ...DialOption) []*grpc.ClientConn {
 	var conns []*grpc.ClientConn
 
 	for i := range ids {
-		if conn, err := c.Dial(ids[i].address, opts...); err == nil {
+		if conn, err := c.Dial(ids[i].Address(), opts...); err == nil {
 			conns = append(conns, conn)
 		}
 	}
@@ -179,7 +179,7 @@ func (c *Client) Dial(addr string, opts ...DialOption) (*grpc.ClientConn, error)
 }
 
 func (c *Client) DialContext(ctx context.Context, addr string) (*grpc.ClientConn, error) {
-	if addr == c.table.self.address {
+	if addr == c.table.self.Address() {
 		return nil, errors.New("attempted to dial self")
 	}
 
@@ -318,10 +318,8 @@ func (c *Client) Bootstrap() (results []*ID) {
 }
 
 func (c *Client) FindNode(target *ID, k int, a int, d int) (results []*ID) {
-	type request ID
-
 	type response struct {
-		requester *request
+		requester *ID
 		ids       []*ID
 	}
 
@@ -344,14 +342,14 @@ func (c *Client) FindNode(target *ID, k int, a int, d int) (results []*ID) {
 
 	for _, lookup := range lookups { // Perform d parallel disjoint lookups.
 		go func(lookup queue.Queue) {
-			requests := make(chan *request, a)
+			requests := make(chan *ID, a)
 			responses := make(chan *response, a)
 
 			for i := 0; i < a; i++ { // Perform α queries in parallel per disjoint lookup.
 				go func() {
 					for id := range requests {
 						f := func() error {
-							conn, err := c.Dial(id.address, WithTimeout(3*time.Second))
+							conn, err := c.Dial(id.Address(), WithTimeout(3*time.Second))
 
 							if err != nil {
 								responses <- nil
@@ -395,7 +393,7 @@ func (c *Client) FindNode(target *ID, k int, a int, d int) (results []*ID) {
 
 			for lookup.Len() > 0 || pending > 0 {
 				for lookup.Len() > 0 && len(requests) < cap(requests) {
-					requests <- (*request)(lookup.PopFront().(*ID))
+					requests <- lookup.PopFront().(*ID)
 					pending++
 				}
 
diff --git a/skademlia/client_test.go b/skademlia/client_test.go
index 2fcc4583..bee80793 100644
--- a/skademlia/client_test.go
+++ b/skademlia/client_test.go
@@ -100,20 +100,30 @@ func TestClientEviction(t *testing.T) {
 		l net.Listener
 	}
 
-	var peers []*peer
+	var peers = struct {
+		peers []*peer
+		sync.RWMutex
+	}{}
+
 	var wg sync.WaitGroup
 
 	for i := 0; i < 3; i++ {
 		wg.Add(1)
 
 		c, lis := getClient(t, 1, 1)
-		peers = append(peers, &peer{
+
+		peers.Lock()
+		peers.peers = append(peers.peers, &peer{
 			c: c,
 			l: lis,
 		})
+		peers.Unlock()
 
 		go func(i int) {
-			p := peers[i]
+			peers.RLock()
+			p := peers.peers[i]
+			peers.RUnlock()
+
 			s := p.c.Listen()
 
 			wg.Done()
@@ -124,7 +134,7 @@ func TestClientEviction(t *testing.T) {
 
 	wg.Wait()
 
-	for _, p := range peers {
+	for _, p := range peers.peers {
 		_, _ = client.Dial(p.l.Addr().String())
 	}
 
diff --git a/skademlia/identity.go b/skademlia/identity.go
index 1db37d8f..c971842b 100644
--- a/skademlia/identity.go
+++ b/skademlia/identity.go
@@ -26,12 +26,15 @@ import (
 	"github.com/pkg/errors"
 	"golang.org/x/crypto/blake2b"
 	"io"
+	"sync"
 )
 
 type ID struct {
 	address   string
 	publicKey edwards25519.PublicKey
 
+	sync.RWMutex
+
 	id, checksum, nonce [blake2b.Size256]byte
 }
 
@@ -49,8 +52,11 @@ func NewID(address string, publicKey edwards25519.PublicKey, nonce [blake2b.Size
 	}
 }
 
-func (m ID) Address() string {
-	return m.address
+func (m *ID) Address() string {
+	m.RLock()
+	t := m.address
+	m.RUnlock()
+	return t
 }
 
 func (m ID) PublicKey() edwards25519.PublicKey {
@@ -66,16 +72,24 @@ func (m ID) Nonce() [blake2b.Size256]byte {
 }
 
 func (m ID) String() string {
-	return fmt.Sprintf("%s[%x]", m.address, m.publicKey)
+	return fmt.Sprintf("%s[%x]", (&m).Address(), m.publicKey)
+}
+
+func (m *ID) SetAddress(address string) {
+	m.Lock()
+	m.address = address
+	m.Unlock()
 }
 
 func (m ID) Marshal() []byte {
-	buf := make([]byte, 2+len(m.address)+edwards25519.SizePublicKey+blake2b.Size256)
+	address := (&m).Address()
+
+	buf := make([]byte, 2+len(address)+edwards25519.SizePublicKey+blake2b.Size256)
 
-	binary.BigEndian.PutUint16(buf[0:2], uint16(len(m.address)))
-	copy(buf[2:2+len(m.address)], m.address)
-	copy(buf[2+len(m.address):2+len(m.address)+edwards25519.SizePublicKey], m.publicKey[:])
-	copy(buf[2+len(m.address)+edwards25519.SizePublicKey:2+len(m.address)+edwards25519.SizePublicKey+blake2b.Size256], m.nonce[:])
+	binary.BigEndian.PutUint16(buf[0:2], uint16(len(address)))
+	copy(buf[2:2+len(address)], address)
+	copy(buf[2+len(address):2+len(address)+edwards25519.SizePublicKey], m.publicKey[:])
+	copy(buf[2+len(address)+edwards25519.SizePublicKey:2+len(address)+edwards25519.SizePublicKey+blake2b.Size256], m.nonce[:])
 
 	return buf
 }
diff --git a/skademlia/protocol.go b/skademlia/protocol.go
index d71f1abf..4dd5aa97 100644
--- a/skademlia/protocol.go
+++ b/skademlia/protocol.go
@@ -87,7 +87,7 @@ func (p Protocol) handshake(info noise.Info, conn net.Conn) (*ID, error) {
 		bucket.Unlock()
 
 		p.client.peersLock.RLock()
-		lastConn, exists := p.client.peers[lastID.address]
+		lastConn, exists := p.client.peers[lastID.Address()]
 		p.client.peersLock.RUnlock()
 
 		if !exists {
@@ -109,7 +109,7 @@ func (p Protocol) handshake(info noise.Info, conn net.Conn) (*ID, error) {
 		// Ping was successful; disallow the current peer from connecting.
 
 		p.client.peersLock.Lock()
-		delete(p.client.peers, id.address)
+		delete(p.client.peers, id.Address())
 		p.client.peersLock.Unlock()
 
 		return nil, errors.New("skademlia: cannot evict any peers to make room for new peer")
@@ -196,7 +196,7 @@ func (p Protocol) Server(info noise.Info, conn net.Conn) (net.Conn, error) {
 	p.client.logger.Printf("Client %s has connected to you.\n", id)
 
 	go func() {
-		if _, err = p.client.Dial(id.address, WithTimeout(3*time.Second)); err != nil {
+		if _, err = p.client.Dial(id.Address(), WithTimeout(3*time.Second)); err != nil {
 			_ = conn.Close()
 		}
 	}()
diff --git a/skademlia/table.go b/skademlia/table.go
index faf90832..d53917af 100644
--- a/skademlia/table.go
+++ b/skademlia/table.go
@@ -111,8 +111,7 @@ func (t *Table) Update(target *ID) error {
 		b.Lock()
 
 		// address might differ for same public key (checksum
-		id := found.Value.(*ID)
-		id.address = target.address
+		found.Value.(*ID).SetAddress(target.address)
 
 		b.MoveToFront(found)
 		b.Unlock()
diff --git a/skademlia/table_test.go b/skademlia/table_test.go
index 01b85aee..ceaeea7c 100644
--- a/skademlia/table_test.go
+++ b/skademlia/table_test.go
@@ -172,9 +172,13 @@ func TestUpdateSamePublicKey(t *testing.T) {
 	// we create new id with same public key but different address
 	addressToChange := "127.0.0.3"
 	updatedCopy := NewID(addressToChange, updated.publicKey, updated.nonce)
-	if !assert.NoError(t, table.Update(updatedCopy)) {
-		return
-	}
+
+	// ensure that updating table (and id) safe to be done concurrently
+	go func() {
+		if !assert.NoError(t, table.Update(updatedCopy)) {
+			return
+		}
+	}()
 
 	found := table.FindClosest(rootID, 10)
 	if !assert.Equal(t, 1, len(found)) {
@@ -183,5 +187,5 @@ func TestUpdateSamePublicKey(t *testing.T) {
 
 	// we expect id in the table to have same public key but updated address
 	assert.Equal(t, updated.publicKey, found[0].publicKey)
-	assert.Equal(t, addressToChange, found[0].address)
+	assert.Equal(t, addressToChange, found[0].Address())
 }

From 465db066f38c1e75fbbae42ef74da3786ed1dd3a Mon Sep 17 00:00:00 2001
From: Andrii Ursulenko <a.ursulenko@gmail.com>
Date: Thu, 26 Sep 2019 13:33:11 +0300
Subject: [PATCH 2/2] ensure update happens

---
 skademlia/table_test.go | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/skademlia/table_test.go b/skademlia/table_test.go
index ceaeea7c..1c85f0a3 100644
--- a/skademlia/table_test.go
+++ b/skademlia/table_test.go
@@ -24,6 +24,7 @@ import (
 	"github.com/perlin-network/noise/edwards25519"
 	"github.com/stretchr/testify/assert"
 	"golang.org/x/crypto/blake2b"
+	"sync"
 	"testing"
 	"testing/quick"
 )
@@ -173,8 +174,11 @@ func TestUpdateSamePublicKey(t *testing.T) {
 	addressToChange := "127.0.0.3"
 	updatedCopy := NewID(addressToChange, updated.publicKey, updated.nonce)
 
+	wg := sync.WaitGroup{}
 	// ensure that updating table (and id) safe to be done concurrently
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
 		if !assert.NoError(t, table.Update(updatedCopy)) {
 			return
 		}
@@ -185,6 +189,8 @@ func TestUpdateSamePublicKey(t *testing.T) {
 		return
 	}
 
+	wg.Wait()
+
 	// we expect id in the table to have same public key but updated address
 	assert.Equal(t, updated.publicKey, found[0].publicKey)
 	assert.Equal(t, addressToChange, found[0].Address())