Skip to content

Commit

Permalink
Update UDPMux to dispatch inbound packets on ufrag
Browse files Browse the repository at this point in the history
If we get a packet for an address we don't know dispatch it
by ufrag still.

Resolves pion#357
  • Loading branch information
Sean-Der committed Apr 21, 2021
1 parent af8539d commit e6e49f5
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 18 deletions.
4 changes: 3 additions & 1 deletion agent_udpmux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ func TestMuxAgent(t *testing.T) {
muxedA, err := NewAgent(&AgentConfig{
UDPMux: udpMux,
CandidateTypes: []CandidateType{CandidateTypeHost},
NetworkTypes: supportedNetworkTypes(),
NetworkTypes: []NetworkType{
NetworkTypeUDP4,
},
})
require.NoError(t, err)

Expand Down
74 changes: 57 additions & 17 deletions udp_mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"io"
"net"
"os"
"strings"
"sync"

"github.com/pion/logging"
"github.com/pion/stun"
)

// UDPMux allows multiple connections to go over a single UDP port
Expand All @@ -27,8 +29,8 @@ type UDPMuxDefault struct {
// conns is a map of all udpMuxedConn indexed by ufrag|network|candidateType
conns map[string]*udpMuxedConn

// map of udpAddr -> udpMuxedConn
addressMap sync.Map
addressMapMu sync.RWMutex
addressMap map[string]*udpMuxedConn

// buffer pool to recycle buffers for net.UDPAddr encodes/decodes
pool *sync.Pool
Expand All @@ -47,6 +49,7 @@ type UDPMuxParams struct {
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
m := &UDPMuxDefault{
addressMap: map[string]*udpMuxedConn{},
params: params,
conns: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
Expand Down Expand Up @@ -109,10 +112,13 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
// keep lock section small to avoid deadlock with conn lock
m.mu.Unlock()

m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()

for _, c := range removedConns {
addresses := c.getAddresses()
for _, addr := range addresses {
m.addressMap.Delete(addr)
delete(m.addressMap, addr)
}
}
}
Expand Down Expand Up @@ -154,9 +160,12 @@ func (m *UDPMuxDefault) removeConn(key string) {
return
}

m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()

addresses := c.getAddresses()
for _, addr := range addresses {
m.addressMap.Delete(addr)
delete(m.addressMap, addr)
}
}

Expand All @@ -168,11 +177,16 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
if m.IsClosed() {
return
}
existing, ok := m.addressMap.Load(addr)

m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()

existing, ok := m.addressMap[addr]
if ok {
existing.(*udpMuxedConn).removeAddress(addr)
existing.removeAddress(addr)
}
m.addressMap.Store(addr, conn)

m.addressMap[addr] = conn
}

func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
Expand All @@ -192,6 +206,7 @@ func (m *UDPMuxDefault) connWorker() {
defer func() {
_ = m.Close()
}()

buf := make([]byte, receiveMTU)
for {
n, addr, err := m.params.UDPConn.ReadFrom(buf)
Expand All @@ -207,21 +222,46 @@ func (m *UDPMuxDefault) connWorker() {
return
}

// look up forward destination
v, ok := m.addressMap.Load(addr.String())
if !ok {
// ignore packets that we don't know where to route to
continue
}

udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
logger.Errorf("underlying PacketConn did not return a UDPAddr")
return
}
c := v.(*udpMuxedConn)
err = c.writePacket(buf[:n], udpAddr)
if err != nil {

// If we have already seen this address dispatch to the appropriate destination
m.addressMapMu.Lock()
destinationConn := m.addressMap[addr.String()]
m.addressMapMu.Unlock()

// If we haven't seen this address before but is a STUN packet lookup by ufrag
if destinationConn == nil && stun.IsMessage(buf[:n]) {
msg := &stun.Message{
Raw: append([]byte{}, buf[:n]...),
}

if err = msg.Decode(); err != nil {
m.params.Logger.Warnf("Failed to handle decode ICE from %s: %v\n", addr.String(), err)
continue
}

attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr != nil {
m.params.Logger.Warnf("No Username attribute in STUN message from %s\n", addr.String())
continue
}

ufrag := strings.Split(string(attr), ":")[0]

m.mu.Lock()
destinationConn = m.conns[ufrag]
m.mu.Unlock()
}

if destinationConn == nil {
continue
}

if err = destinationConn.writePacket(buf[:n], udpAddr); err != nil {
logger.Errorf("could not write packet: %v", err)
}
}
Expand Down
1 change: 1 addition & 0 deletions udp_muxed_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ func (c *udpMuxedConn) addAddress(addr string) {
func (c *udpMuxedConn) removeAddress(addr string) {
c.mu.Lock()
defer c.mu.Unlock()

newAddresses := make([]string, 0, len(c.addresses))
for _, a := range c.addresses {
if a != addr {
Expand Down

0 comments on commit e6e49f5

Please sign in to comment.