Skip to content

Commit

Permalink
fix data race
Browse files Browse the repository at this point in the history
  • Loading branch information
ginuerzh committed Nov 21, 2018
1 parent e9b872c commit a020c7b
Show file tree
Hide file tree
Showing 11 changed files with 227 additions and 99 deletions.
34 changes: 27 additions & 7 deletions bypass.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ type Bypass struct {
matchers []Matcher
reversed bool
period time.Duration // the period for live reloading
mux sync.Mutex
mux sync.RWMutex
}

// NewBypass creates and initializes a new Bypass using matchers as its match rules.
Expand Down Expand Up @@ -160,8 +160,8 @@ func (bp *Bypass) Contains(addr string) bool {
}
}

bp.mux.Lock()
defer bp.mux.Unlock()
bp.mux.RLock()
defer bp.mux.RUnlock()

var matched bool
for _, matcher := range bp.matchers {
Expand All @@ -179,22 +179,33 @@ func (bp *Bypass) Contains(addr string) bool {

// AddMatchers appends matchers to the bypass matcher list.
func (bp *Bypass) AddMatchers(matchers ...Matcher) {
bp.mux.Lock()
defer bp.mux.Unlock()

bp.matchers = append(bp.matchers, matchers...)
}

// Matchers return the bypass matcher list.
func (bp *Bypass) Matchers() []Matcher {
bp.mux.RLock()
defer bp.mux.RUnlock()

return bp.matchers
}

// Reversed reports whether the rules of the bypass are reversed.
func (bp *Bypass) Reversed() bool {
bp.mux.RLock()
defer bp.mux.RUnlock()

return bp.reversed
}

// Reload parses config from r, then live reloads the bypass.
func (bp *Bypass) Reload(r io.Reader) error {
var matchers []Matcher
var period time.Duration
var reversed bool

scanner := bufio.NewScanner(r)
for scanner.Scan() {
Expand All @@ -217,7 +228,7 @@ func (bp *Bypass) Reload(r io.Reader) error {
}
}
if len(ss) == 2 {
bp.period, _ = time.ParseDuration(ss[1])
period, _ = time.ParseDuration(ss[1])
continue
}
}
Expand All @@ -231,7 +242,7 @@ func (bp *Bypass) Reload(r io.Reader) error {
}
}
if len(ss) == 2 {
bp.reversed, _ = strconv.ParseBool(ss[1])
reversed, _ = strconv.ParseBool(ss[1])
continue
}
}
Expand All @@ -247,19 +258,28 @@ func (bp *Bypass) Reload(r io.Reader) error {
defer bp.mux.Unlock()

bp.matchers = matchers
bp.period = period
bp.reversed = reversed

return nil
}

// Period returns the reload period
func (bp *Bypass) Period() time.Duration {
bp.mux.RLock()
defer bp.mux.RUnlock()

return bp.period
}

func (bp *Bypass) String() string {
bp.mux.RLock()
defer bp.mux.RUnlock()

b := &bytes.Buffer{}
fmt.Fprintf(b, "reversed: %v\n", bp.Reversed())
for _, m := range bp.Matchers() {
fmt.Fprintf(b, "reversed: %v\n", bp.reversed)
fmt.Fprintf(b, "reload: %v\n", bp.period)
for _, m := range bp.matchers {
b.WriteString(m.String())
b.WriteByte('\n')
}
Expand Down
18 changes: 9 additions & 9 deletions chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func newRoute(nodes ...Node) *Chain {
}

// Nodes returns the proxy nodes that the chain holds.
// If a node is a node group, the first node in the group will be returned.
// The first node in each group will be returned.
func (c *Chain) Nodes() (nodes []Node) {
for _, group := range c.nodeGroups {
if ns := group.Nodes(); len(ns) > 0 {
Expand All @@ -61,7 +61,7 @@ func (c *Chain) LastNode() Node {
return Node{}
}
group := c.nodeGroups[len(c.nodeGroups)-1]
return group.nodes[0].Clone()
return group.GetNode(0)
}

// LastNodeGroup returns the last group of the group list.
Expand Down Expand Up @@ -173,7 +173,6 @@ func (c *Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string {
}

// Conn obtains a handshaked connection to the last node of the chain.
// If the chain is empty, it returns an ErrEmptyChain error.
func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
options := &ChainOptions{}
for _, opt := range opts {
Expand Down Expand Up @@ -206,6 +205,7 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
}

// getConn obtains a connection to the last node of the chain.
// It does not handshake with the last node.
func (c *Chain) getConn() (conn net.Conn, err error) {
if c.IsEmpty() {
err = ErrEmptyChain
Expand All @@ -216,33 +216,33 @@ func (c *Chain) getConn() (conn net.Conn, err error) {

cn, err := node.Client.Dial(node.Addr, node.DialOptions...)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
return
}

cn, err = node.Client.Handshake(cn, node.HandshakeOptions...)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
return
}
node.ResetDead()
node.group.ResetDeadNode(node.ID)

preNode := node
for _, node := range nodes[1:] {
var cc net.Conn
cc, err = preNode.Client.Connect(cn, node.Addr)
if err != nil {
cn.Close()
node.MarkDead()
node.group.MarkDeadNode(node.ID)
return
}
cc, err = node.Client.Handshake(cc, node.HandshakeOptions...)
if err != nil {
cn.Close()
node.MarkDead()
node.group.MarkDeadNode(node.ID)
return
}
node.ResetDead()
node.group.ResetDeadNode(node.ID)

cn = cc
preNode = node
Expand Down
3 changes: 2 additions & 1 deletion cmd/gost/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ func (cfg *peerConfig) Reload(r io.Reader) error {
strategy = s
}
}
group.Options = append([]gost.SelectOption{},
group.SetSelector(
nil,
gost.WithFilter(&gost.FailFilter{
MaxFails: cfg.MaxFails,
FailTimeout: time.Duration(cfg.FailTimeout) * time.Second,
Expand Down
20 changes: 10 additions & 10 deletions forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
)
if err != nil {
log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), node.Addr, err)
node.MarkDead()
node.group.MarkDeadNode(node.ID)
} else {
break
}
Expand All @@ -116,7 +116,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
return
}

node.ResetDead()
node.group.ResetDeadNode(node.ID)
defer cc.Close()

log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), node.Addr)
Expand Down Expand Up @@ -191,13 +191,13 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
if h.options.Chain.IsEmpty() {
raddr, err := net.ResolveUDPAddr("udp", node.Addr)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
return
}
cc, err = net.DialUDP("udp", nil, raddr)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
log.Logf("[udp] %s - %s : %s", conn.LocalAddr(), node.Addr, err)
return
}
Expand All @@ -212,7 +212,7 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
}

defer cc.Close()
node.ResetDead()
node.group.ResetDeadNode(node.ID)

log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), node.Addr)
transport(conn, cc)
Expand Down Expand Up @@ -291,7 +291,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
cc, err = net.DialTimeout("tcp", node.Addr, h.options.Timeout)
if err != nil {
log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), node.Addr, err)
node.MarkDead()
node.group.MarkDeadNode(node.ID)
} else {
break
}
Expand All @@ -301,7 +301,7 @@ func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) {
}

defer cc.Close()
node.ResetDead()
node.group.ResetDeadNode(node.ID)

log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), node.Addr)
transport(cc, conn)
Expand Down Expand Up @@ -369,18 +369,18 @@ func (h *udpRemoteForwardHandler) Handle(conn net.Conn) {

raddr, err := net.ResolveUDPAddr("udp", node.Addr)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
cc, err := net.DialUDP("udp", nil, raddr)
if err != nil {
node.MarkDead()
node.group.MarkDeadNode(node.ID)
log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
defer cc.Close()
node.ResetDead()
node.group.ResetDeadNode(node.ID)

log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), node.Addr)
transport(conn, cc)
Expand Down
19 changes: 18 additions & 1 deletion hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"net"
"strings"
"sync"
"time"

"github.com/go-log/log"
Expand All @@ -25,6 +26,7 @@ type Host struct {
type Hosts struct {
hosts []Host
period time.Duration
mux sync.RWMutex
}

// NewHosts creates a Hosts with optional list of host
Expand All @@ -36,6 +38,9 @@ func NewHosts(hosts ...Host) *Hosts {

// AddHost adds host(s) to the host table.
func (h *Hosts) AddHost(host ...Host) {
h.mux.Lock()
defer h.mux.Unlock()

h.hosts = append(h.hosts, host...)
}

Expand All @@ -44,6 +49,10 @@ func (h *Hosts) Lookup(host string) (ip net.IP) {
if h == nil {
return
}

h.mux.RLock()
defer h.mux.RUnlock()

for _, h := range h.hosts {
if h.Hostname == host {
ip = h.IP
Expand All @@ -64,6 +73,7 @@ func (h *Hosts) Lookup(host string) (ip net.IP) {

// Reload parses config from r, then live reloads the hosts.
func (h *Hosts) Reload(r io.Reader) error {
var period time.Duration
var hosts []Host

scanner := bufio.NewScanner(r)
Expand All @@ -89,7 +99,7 @@ func (h *Hosts) Reload(r io.Reader) error {

// reload option
if strings.ToLower(ss[0]) == "reload" {
h.period, _ = time.ParseDuration(ss[1])
period, _ = time.ParseDuration(ss[1])
continue
}

Expand All @@ -110,11 +120,18 @@ func (h *Hosts) Reload(r io.Reader) error {
return err
}

h.mux.Lock()
h.period = period
h.hosts = hosts
h.mux.Unlock()

return nil
}

// Period returns the reload period
func (h *Hosts) Period() time.Duration {
h.mux.RLock()
defer h.mux.RUnlock()

return h.period
}
2 changes: 1 addition & 1 deletion http2.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) {

u, p, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization"))
if Debug && (u != "" || p != "") {
log.Logf("[http] %s - %s : Authorization: '%s' '%s'", r.RemoteAddr, target, u, p)
log.Logf("[http2] %s - %s : Authorization: '%s' '%s'", r.RemoteAddr, target, u, p)
}
if !authenticate(u, p, h.options.Users...) {
// probing resistance is enabled
Expand Down
Loading

0 comments on commit a020c7b

Please sign in to comment.