forked from hashicorp/raft
-
Notifications
You must be signed in to change notification settings - Fork 0
/
inmem_transport.go
157 lines (133 loc) · 3.56 KB
/
inmem_transport.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
package raft
import (
"fmt"
"io"
"net"
"sync"
"time"
)
// InmemAddr is an in-memory address that satisfies net.Addr (see its
// Network and String methods). It carries no host/port information; the
// Id alone identifies the endpoint.
type InmemAddr struct {
	// Id is the opaque identifier returned verbatim by String.
	Id string
}
// NewInmemAddr returns a new in-memory address whose Id is produced by
// generateUUID, so each call yields a distinct address (assuming
// generateUUID returns unique values).
func NewInmemAddr() *InmemAddr {
	addr := &InmemAddr{Id: generateUUID()}
	return addr
}
// Network reports the network name for in-memory addresses, which is
// always "inmem".
func (ia *InmemAddr) Network() string {
	const network = "inmem"
	return network
}
// String returns the address's Id unchanged; it is also the key used to
// route RPCs between connected in-memory transports.
func (ia *InmemAddr) String() string {
	id := ia.Id
	return id
}
// InmemTransport implements the Transport interface so Raft can be tested
// entirely in memory, without going over a network. Transports are wired
// together explicitly with Connect.
type InmemTransport struct {
	sync.RWMutex // guards peers

	consumerCh chan RPC                   // RPCs delivered to this transport
	localAddr  *InmemAddr                 // address this transport answers to
	peers      map[string]*InmemTransport // routes, keyed by peer address string
	timeout    time.Duration              // base RPC timeout
}
// NewInmemTransport creates a transport bound to a freshly generated local
// address. The transport starts with no peers, a consumer channel buffered
// to 16 RPCs, and a 50ms base timeout. It returns the address together
// with the transport.
func NewInmemTransport() (*InmemAddr, *InmemTransport) {
	local := NewInmemAddr()
	return local, &InmemTransport{
		consumerCh: make(chan RPC, 16),
		localAddr:  local,
		peers:      make(map[string]*InmemTransport),
		timeout:    50 * time.Millisecond,
	}
}
// Consumer returns a receive-only view of the channel on which RPCs sent
// to this transport arrive.
func (i *InmemTransport) Consumer() <-chan RPC {
	var incoming <-chan RPC = i.consumerCh
	return incoming
}
// LocalAddr returns the address this transport was created with.
func (i *InmemTransport) LocalAddr() net.Addr {
	var self net.Addr = i.localAddr
	return self
}
// AppendEntries sends an AppendEntries RPC to target over the in-memory
// transport, waiting up to the transport's base timeout, and copies the
// peer's reply into resp.
//
// It returns the error from the RPC itself (unknown peer, timeout, or an
// error reported by the peer). It also returns an error if the peer's reply
// is not a non-nil *AppendEntriesResponse. In that case resp is left
// untouched.
func (i *InmemTransport) AppendEntries(target net.Addr, args *AppendEntriesRequest, resp *AppendEntriesResponse) error {
	rpcResp, err := i.makeRPC(target, args, nil, i.timeout)
	if err != nil {
		return err
	}

	// Use the comma-ok form: a blind assertion (or a nil pointer) would
	// panic the calling goroutine if the peer's handler replied with
	// something unexpected.
	out, ok := rpcResp.Response.(*AppendEntriesResponse)
	if !ok || out == nil {
		return fmt.Errorf("unexpected AppendEntries response: %T", rpcResp.Response)
	}
	*resp = *out
	return nil
}
// RequestVote sends a RequestVote RPC to target over the in-memory
// transport, waiting up to the transport's base timeout, and copies the
// peer's reply into resp.
//
// It returns the error from the RPC itself (unknown peer, timeout, or an
// error reported by the peer). It also returns an error if the peer's reply
// is not a non-nil *RequestVoteResponse. In that case resp is left
// untouched.
func (i *InmemTransport) RequestVote(target net.Addr, args *RequestVoteRequest, resp *RequestVoteResponse) error {
	rpcResp, err := i.makeRPC(target, args, nil, i.timeout)
	if err != nil {
		return err
	}

	// Use the comma-ok form: a blind assertion (or a nil pointer) would
	// panic the calling goroutine if the peer's handler replied with
	// something unexpected.
	out, ok := rpcResp.Response.(*RequestVoteResponse)
	if !ok || out == nil {
		return fmt.Errorf("unexpected RequestVote response: %T", rpcResp.Response)
	}
	*resp = *out
	return nil
}
// InstallSnapshot sends an InstallSnapshot RPC to target over the
// in-memory transport, passing data along as the RPC's reader, and copies
// the peer's reply into resp.
//
// It waits ten times the transport's base timeout, presumably because
// consuming snapshot data takes longer than an ordinary RPC.
//
// It returns the error from the RPC itself (unknown peer, timeout, or an
// error reported by the peer). It also returns an error if the peer's reply
// is not a non-nil *InstallSnapshotResponse. In that case resp is left
// untouched.
func (i *InmemTransport) InstallSnapshot(target net.Addr, args *InstallSnapshotRequest, resp *InstallSnapshotResponse, data io.Reader) error {
	rpcResp, err := i.makeRPC(target, args, data, 10*i.timeout)
	if err != nil {
		return err
	}

	// Use the comma-ok form: a blind assertion (or a nil pointer) would
	// panic the calling goroutine if the peer's handler replied with
	// something unexpected.
	out, ok := rpcResp.Response.(*InstallSnapshotResponse)
	if !ok || out == nil {
		return fmt.Errorf("unexpected InstallSnapshot response: %T", rpcResp.Response)
	}
	*resp = *out
	return nil
}
// makeRPC delivers args (and the optional reader r) to the transport
// registered for target, then waits for the peer's reply.
//
// Both phases are bounded by timeout:
//   - handing the request to the peer's consumer channel, which can fill up
//     if the peer stops reading;
//   - waiting for the response.
//
// It returns an error if target has not been connected, if either phase
// times out, or if the peer's response carries an error. In the last case
// the peer's RPCResponse is returned alongside the error.
func (i *InmemTransport) makeRPC(target net.Addr, args interface{}, r io.Reader, timeout time.Duration) (rpcResp RPCResponse, err error) {
	i.RLock()
	peer, ok := i.peers[target.String()]
	i.RUnlock()
	if !ok {
		return rpcResp, fmt.Errorf("Failed to connect to peer: %v", target)
	}

	// Buffer one slot so a peer that answers after we have stopped waiting
	// can still complete its send instead of blocking forever.
	respCh := make(chan RPCResponse, 1)
	req := RPC{
		Command:  args,
		Reader:   r,
		RespChan: respCh,
	}

	// Send the RPC over, but do not block past the timeout if the peer's
	// consumer channel is full and nobody is draining it.
	select {
	case peer.consumerCh <- req:
	case <-time.After(timeout):
		return rpcResp, fmt.Errorf("send timed out")
	}

	// Wait for a response
	select {
	case rpcResp = <-respCh:
		if rpcResp.Error != nil {
			err = rpcResp.Error
		}
	case <-time.After(timeout):
		err = fmt.Errorf("command timed out")
	}
	return rpcResp, err
}
// EncodePeer serializes a peer address as the raw bytes of its String()
// form. For an InmemAddr that is its Id.
func (i *InmemTransport) EncodePeer(p net.Addr) []byte {
	encoded := p.String()
	return []byte(encoded)
}
// DecodePeer reverses EncodePeer by wrapping the bytes, as a string, in a
// new *InmemAddr.
func (i *InmemTransport) DecodePeer(buf []byte) net.Addr {
	decoded := &InmemAddr{Id: string(buf)}
	return decoded
}
// Connect registers trans as the destination for RPCs this transport sends
// to peer, enabling purely local routing. Any existing route for the same
// address string is replaced.
func (i *InmemTransport) Connect(peer net.Addr, trans *InmemTransport) {
	key := peer.String()

	i.Lock()
	defer i.Unlock()
	i.peers[key] = trans
}
// Disconnect removes the route to peer, if any, so later RPCs to it fail
// with a connection error. Removing an unknown peer is a no-op.
func (i *InmemTransport) Disconnect(peer net.Addr) {
	key := peer.String()

	i.Lock()
	defer i.Unlock()
	delete(i.peers, key)
}
// DisconnectAll drops every route this transport knows about by swapping
// in a fresh, empty peer map.
func (i *InmemTransport) DisconnectAll() {
	empty := make(map[string]*InmemTransport)

	i.Lock()
	defer i.Unlock()
	i.peers = empty
}