Skip to content

Commit

Permalink
Fix: tunnel UDP race condition (#1043)
Browse files Browse the repository at this point in the history
  • Loading branch information
xjasonlyu authored Oct 28, 2020
1 parent ba060bd commit 2cd1b89
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 44 deletions.
6 changes: 3 additions & 3 deletions component/nat/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ func (t *Table) Get(key string) C.PacketConn {
return item.(C.PacketConn)
}

func (t *Table) GetOrCreateLock(key string) (*sync.WaitGroup, bool) {
item, loaded := t.mapping.LoadOrStore(key, &sync.WaitGroup{})
return item.(*sync.WaitGroup), loaded
func (t *Table) GetOrCreateLock(key string) (*sync.Cond, bool) {
item, loaded := t.mapping.LoadOrStore(key, sync.NewCond(&sync.Mutex{}))
return item.(*sync.Cond), loaded
}

func (t *Table) Delete(key string) {
Expand Down
90 changes: 49 additions & 41 deletions tunnel/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
return
}

// make a fAddr if requset ip is fakeip
// make a fAddr if request ip is fakeip
var fAddr net.Addr
if resolver.IsExistFakeIP(metadata.DstIP) {
fAddr = metadata.UDPAddr()
Expand All @@ -176,57 +176,65 @@ func handleUDPConn(packet *inbound.PacketAdapter) {
}

key := packet.LocalAddr().String()
pc := natTable.Get(key)
if pc != nil {
handleUDPToRemote(packet, pc, metadata)

handle := func() bool {
pc := natTable.Get(key)
if pc != nil {
handleUDPToRemote(packet, pc, metadata)
return true
}
return false
}

if handle() {
return
}

lockKey := key + "-lock"
wg, loaded := natTable.GetOrCreateLock(lockKey)
cond, loaded := natTable.GetOrCreateLock(lockKey)

go func() {
if !loaded {
wg.Add(1)
proxy, rule, err := resolveMetadata(metadata)
if err != nil {
log.Warnln("[UDP] Parse metadata failed: %s", err.Error())
natTable.Delete(lockKey)
wg.Done()
return
}

rawPc, err := proxy.DialUDP(metadata)
if err != nil {
log.Warnln("[UDP] dial %s error: %s", proxy.Name(), err.Error())
natTable.Delete(lockKey)
wg.Done()
return
}
pc = newUDPTracker(rawPc, DefaultManager, metadata, rule)

switch true {
case rule != nil:
log.Infoln("[UDP] %s --> %v match %s(%s) using %s", metadata.SourceAddress(), metadata.String(), rule.RuleType().String(), rule.Payload(), rawPc.Chains().String())
case mode == Global:
log.Infoln("[UDP] %s --> %v using GLOBAL", metadata.SourceAddress(), metadata.String())
case mode == Direct:
log.Infoln("[UDP] %s --> %v using DIRECT", metadata.SourceAddress(), metadata.String())
default:
log.Infoln("[UDP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String())
}
if loaded {
cond.L.Lock()
cond.Wait()
handle()
cond.L.Unlock()
return
}

natTable.Set(key, pc)
defer func() {
natTable.Delete(lockKey)
wg.Done()
go handleUDPToLocal(packet.UDPPacket, pc, key, fAddr)
cond.Broadcast()
}()

proxy, rule, err := resolveMetadata(metadata)
if err != nil {
log.Warnln("[UDP] Parse metadata failed: %s", err.Error())
return
}

wg.Wait()
pc := natTable.Get(key)
if pc != nil {
handleUDPToRemote(packet, pc, metadata)
rawPc, err := proxy.DialUDP(metadata)
if err != nil {
log.Warnln("[UDP] dial %s error: %s", proxy.Name(), err.Error())
return
}
pc := newUDPTracker(rawPc, DefaultManager, metadata, rule)

switch true {
case rule != nil:
log.Infoln("[UDP] %s --> %v match %s(%s) using %s", metadata.SourceAddress(), metadata.String(), rule.RuleType().String(), rule.Payload(), rawPc.Chains().String())
case mode == Global:
log.Infoln("[UDP] %s --> %v using GLOBAL", metadata.SourceAddress(), metadata.String())
case mode == Direct:
log.Infoln("[UDP] %s --> %v using DIRECT", metadata.SourceAddress(), metadata.String())
default:
log.Infoln("[UDP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String())
}

go handleUDPToLocal(packet.UDPPacket, pc, key, fAddr)

natTable.Set(key, pc)
handle()
}()
}

Expand Down

0 comments on commit 2cd1b89

Please sign in to comment.