Skip to content

Commit

Permalink
Fix golangci-lint and add more unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AndyLc authored and stv0g committed May 19, 2023
1 parent 03cd97b commit b6b3453
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 67 deletions.
2 changes: 2 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,8 @@ func (c *Client) handleSTUNMessage(data []byte, from net.Addr) error {
}

allocation.HandleConnectionAttempt(addr, cid)
default:
c.log.Debug("received unsupported STUN method")
}
return nil
}
Expand Down
4 changes: 0 additions & 4 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,6 @@ func TestTCPClient(t *testing.T) {
allocation, err := client.AllocateTCP()
require.NoError(t, err)

// TODO: Implement server side handling of Connect and ConnectionBind
// _, err = allocation.Dial(&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080})
// assert.NoError(t, err)

// Shutdown
require.NoError(t, allocation.Close())
require.NoError(t, conn.Close())
Expand Down
41 changes: 21 additions & 20 deletions examples/turn-client/tcp-alloc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,22 @@ func setupSignalingChannel(addrCh chan string, signaling bool, relayAddr string)
addr := "127.0.0.1:5000"
if signaling {
go func() {
listen, err := net.Listen("tcp", addr)
listener, err := net.Listen("tcp", addr)
if err != nil {
log.Panicf("Failed to create signaling server: %s", err)
}
defer listen.Close()
defer listener.Close() //nolint:errcheck,gosec
for {
conn, err := listen.Accept()
conn, err := listener.Accept()
if err != nil {
log.Panicf("Failed to accept: %s", err)
}

go func() {
message, err := bufio.NewReader(conn).ReadString('\n')
var message string
message, err = bufio.NewReader(conn).ReadString('\n')
if err != nil {
log.Panicf("Failed to read relayAddr: %s", err)
log.Panicf("Failed to read from relayAddr: %s", err)
}
addrCh <- message[:len(message)-1]
}()
Expand Down Expand Up @@ -70,19 +71,19 @@ func main() {
flag.Parse()

if len(*host) == 0 {
log.Fatalf("'host' is required")
log.Panicf("'host' is required")
}

if len(*user) == 0 {
log.Fatalf("'user' is required")
log.Panicf("'user' is required")
}

// Dial TURN Server
turnServerAddrStr := fmt.Sprintf("%s:%d", *host, *port)

turnServerAddr, err := net.ResolveTCPAddr("tcp", turnServerAddrStr)
if err != nil {
log.Fatalf("Failed to resolve TURN server address: %s", err)
log.Panicf("Failed to resolve TURN server address: %s", err)
}

conn, err := net.DialTCP("tcp", nil, turnServerAddr)
Expand Down Expand Up @@ -125,7 +126,7 @@ func main() {
}
defer func() {
if closeErr := allocation.Close(); closeErr != nil {
log.Fatalf("Failed to close connection: %s", closeErr)
log.Panicf("Failed to close connection: %s", closeErr)
}
}()

Expand All @@ -139,7 +140,7 @@ func main() {
peerAddrStr := <-addrCh
peerAddr, err := net.ResolveTCPAddr("tcp", peerAddrStr)
if err != nil {
log.Fatalf("Failed to resolve peer address: %s", err)
log.Panicf("Failed to resolve peer address: %s", err)
}

log.Printf("Received peer address: %s", peerAddrStr)
Expand All @@ -149,44 +150,44 @@ func main() {
if *signaling {
conn, err := allocation.DialTCP("tcp", nil, peerAddr)
if err != nil {
log.Fatalf("Failed to dial: %s", err)
log.Panicf("Failed to dial: %s", err)
}

if _, err := conn.Write([]byte("hello!")); err != nil {
log.Fatalf("Failed to write: %s", err)
if _, err = conn.Write([]byte("hello!")); err != nil {
log.Panicf("Failed to write: %s", err)
}

n, err = conn.Read(buf)
if err != nil {
log.Fatalf("Failed to read from relay connection: %s", err)
log.Panicf("Failed to read from relay connection: %s", err)
}

if err := conn.Close(); err != nil {
log.Fatalf("Failed to close: %s", err)
log.Panicf("Failed to close: %s", err)
}
} else {
if err := client.CreatePermission(peerAddr); err != nil {
log.Fatalf("Failed to create permission: %s", err)
log.Panicf("Failed to create permission: %s", err)
}

conn, err := allocation.AcceptTCP()
if err != nil {
log.Fatalf("Failed to accept TCP connection: %s", err)
log.Panicf("Failed to accept TCP connection: %s", err)
}

log.Printf("Accepted connection from: %s", conn.RemoteAddr())

n, err = conn.Read(buf)
if err != nil {
log.Fatalf("Failed to read from relay conn: %s", err)
log.Panicf("Failed to read from relay conn: %s", err)
}

if _, err := conn.Write([]byte("hello back!")); err != nil {
log.Fatalf("Failed to write: %s", err)
log.Panicf("Failed to write: %s", err)
}

if err := conn.Close(); err != nil {
log.Fatalf("Failed to close: %s", err)
log.Panicf("Failed to close: %s", err)
}
}

Expand Down
2 changes: 2 additions & 0 deletions internal/client/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ var (
errFailedToBuildRefreshRequest = errors.New("failed to build refresh request")
errFailedToRefreshAllocation = errors.New("failed to refresh allocation")
errFailedToGetLifetime = errors.New("failed to get lifetime from refresh response")
errInvalidTURNAddress = errors.New("invalid TURN server address")
errUnexpectedSTUNRequestMessage = errors.New("unexpected STUN request message")
)

type timeoutError struct {
Expand Down
35 changes: 19 additions & 16 deletions internal/client/tcp_alloc.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ func NewTCPAllocation(config *AllocationConfig) *TCPAllocation {
return a
}

// Connect sends a Connect request to the turn server and returns a chosen connection ID
func (a *TCPAllocation) Connect(peer net.Addr) (proto.ConnectionID, error) {
setters := []stun.Setter{
stun.TransactionID,
Expand All @@ -96,15 +97,16 @@ func (a *TCPAllocation) Connect(peer net.Addr) (proto.ConnectionID, error) {
if err != nil {
return 0, err
}

res := trRes.Msg

if res.Type.Class == stun.ClassErrorResponse {
var code stun.ErrorCodeAttribute
if err = code.GetFrom(res); err == nil {
return 0, fmt.Errorf("%s (error %s)", res.Type, code)
return 0, fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113
}

return 0, fmt.Errorf("%s", res.Type)
return 0, fmt.Errorf("%s", res.Type) //nolint:goerr113
}

var cid proto.ConnectionID
Expand Down Expand Up @@ -140,13 +142,13 @@ func (a *TCPAllocation) DialWithConn(conn net.Conn, network, rAddrStr string) (*
// DialTCP acts like Dial for TCP networks.
func (a *TCPAllocation) DialTCP(network string, lAddr, rAddr *net.TCPAddr) (*TCPConn, error) {
var rAddrServer *net.TCPAddr
if addr, ok := a.client.TURNServerAddr().(*net.UDPAddr); ok {
if addr, ok := a.client.TURNServerAddr().(*net.TCPAddr); ok {
rAddrServer = &net.TCPAddr{
IP: addr.IP,
Port: addr.Port,
}
} else {
return nil, fmt.Errorf("invalid TURN server address")
return nil, errInvalidTURNAddress
}

conn, err := net.DialTCP(network, lAddr, rAddrServer)
Expand All @@ -156,14 +158,14 @@ func (a *TCPAllocation) DialTCP(network string, lAddr, rAddr *net.TCPAddr) (*TCP

dataConn, err := a.DialTCPWithConn(conn, network, rAddr)
if err != nil {
conn.Close() //nolint:errcheck
conn.Close() //nolint:errcheck,gosec
}

return dataConn, err
}

// DialTCPWithConn acts like DialWithConn for TCP networks.
func (a *TCPAllocation) DialTCPWithConn(conn net.Conn, network string, rAddr *net.TCPAddr) (*TCPConn, error) {
func (a *TCPAllocation) DialTCPWithConn(conn net.Conn, _ string, rAddr *net.TCPAddr) (*TCPConn, error) {
var err error

// Check if we have a permission for the destination IP addr
Expand All @@ -188,13 +190,14 @@ func (a *TCPAllocation) DialTCPWithConn(conn net.Conn, network string, rAddr *ne
return nil, err
}

tcpConn, ok := conn.(*net.TCPConn)
tcpConn, ok := conn.(transport.TCPConn)
if !ok {
return nil, errTCPAddrCast
}

dataConn := &TCPConn{
TCPConn: tcpConn,
ConnectionID: cid,
remoteAddress: rAddr,
allocation: a,
}
Expand Down Expand Up @@ -250,22 +253,22 @@ func (a *TCPAllocation) BindConnection(dataConn *TCPConn, cid proto.ConnectionID
return err
}
res := &stun.Message{Raw: raw}
if err := res.Decode(); err != nil {
if err = res.Decode(); err != nil {
return fmt.Errorf("failed to decode STUN message: %w", err)
}

switch res.Type.Class {
case stun.ClassErrorResponse:
var code stun.ErrorCodeAttribute
if err = code.GetFrom(res); err == nil {
return fmt.Errorf("%s (error %s)", res.Type, code)
return fmt.Errorf("%s (error %s)", res.Type, code) //nolint:goerr113
}
return fmt.Errorf("%s", res.Type)
return fmt.Errorf("%s", res.Type) //nolint:goerr113
case stun.ClassSuccessResponse:
a.log.Debug("Successful connectionBind request")
return nil
default:
return fmt.Errorf("unexpected STUN request message: %s", res.String())
return fmt.Errorf("%w: %s", errUnexpectedSTUNRequestMessage, res.String())
}
}

Expand All @@ -288,18 +291,18 @@ func (a *TCPAllocation) AcceptTCP() (transport.TCPConn, error) {

dataConn, err := a.AcceptTCPWithConn(tcpConn)
if err != nil {
tcpConn.Close() //nolint: errcheck
tcpConn.Close() //nolint:errcheck,gosec
}

return dataConn, err
}

// AcceptTCPWithConn accepts the next incoming call and returns the new connection.
func (a *TCPAllocation) AcceptTCPWithConn(conn net.Conn) (transport.TCPConn, error) {
func (a *TCPAllocation) AcceptTCPWithConn(conn net.Conn) (*TCPConn, error) {
select {
case attempt := <-a.connAttemptCh:

tcpConn, ok := conn.(*net.TCPConn)
tcpConn, ok := conn.(transport.TCPConn)
if !ok {
return nil, errTCPAddrCast
}
Expand All @@ -326,6 +329,7 @@ func (a *TCPAllocation) AcceptTCPWithConn(conn net.Conn) (transport.TCPConn, err
}
}

// SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline.
func (a *TCPAllocation) SetDeadline(t time.Time) error {
var d time.Duration
if t == noDeadline() {
Expand Down Expand Up @@ -355,10 +359,9 @@ func (a *TCPAllocation) Addr() net.Addr {

// HandleConnectionAttempt is called by the TURN client
// when it receives a ConnectionAttempt indication.
func (a *TCPAllocation) HandleConnectionAttempt(from *net.TCPAddr, cid proto.ConnectionID) error {
func (a *TCPAllocation) HandleConnectionAttempt(from *net.TCPAddr, cid proto.ConnectionID) {
a.connAttemptCh <- &connectionAttempt{
from: from,
cid: cid,
}
return nil
}
12 changes: 5 additions & 7 deletions internal/client/tcp_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package client
import (
"errors"
"net"
"time"

"github.com/pion/transport/v2"
"github.com/pion/turn/v2/internal/proto"
Expand All @@ -23,14 +22,13 @@ const (

var _ transport.TCPConn = (*TCPConn)(nil) // Includes type check for net.Conn

// TCPConn wraps a net.TCPConn and returns the allocations relayed
// TCPConn wraps a transport.TCPConn and returns the allocations relayed
// transport address in response to TCPConn.LocalAddress()
type TCPConn struct {
*net.TCPConn
remoteAddress *net.TCPAddr
allocation *TCPAllocation
acceptDeadline time.Duration
ConnectionID proto.ConnectionID
transport.TCPConn
remoteAddress *net.TCPAddr
allocation *TCPAllocation
ConnectionID proto.ConnectionID
}

type connectionAttempt struct {
Expand Down
Loading

0 comments on commit b6b3453

Please sign in to comment.