Skip to content

Commit

Permalink
use net.JoinHostPort instead of fmt.Sprintf (fatedier#2791)
Browse files Browse the repository at this point in the history
  • Loading branch information
fatedier authored Feb 9, 2022
1 parent b2311e5 commit 6194273
Show file tree
Hide file tree
Showing 17 changed files with 61 additions and 49 deletions.
16 changes: 6 additions & 10 deletions client/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,22 +347,18 @@ func (pxy *XTCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
xl.Trace("get natHoleRespMsg, sid [%s], client address [%s] visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr)

// Send detect message
array := strings.Split(natHoleRespMsg.VisitorAddr, ":")
if len(array) <= 1 {
xl.Error("get NatHoleResp visitor address error: %v", natHoleRespMsg.VisitorAddr)
host, portStr, err := net.SplitHostPort(natHoleRespMsg.VisitorAddr)
if err != nil {
xl.Error("get NatHoleResp visitor address [%s] error: %v", natHoleRespMsg.VisitorAddr, err)
}
laddr, _ := net.ResolveUDPAddr("udp", clientConn.LocalAddr().String())
/*
for i := 1000; i < 65000; i++ {
pxy.sendDetectMsg(array[0], int64(i), laddr, "a")
}
*/
port, err := strconv.ParseInt(array[1], 10, 64)

port, err := strconv.ParseInt(portStr, 10, 64)
if err != nil {
xl.Error("get natHoleResp visitor address error: %v", natHoleRespMsg.VisitorAddr)
return
}
pxy.sendDetectMsg(array[0], int(port), laddr, []byte(natHoleRespMsg.Sid))
pxy.sendDetectMsg(host, int(port), laddr, []byte(natHoleRespMsg.Sid))
xl.Trace("send all detect msg done")

msg.WriteMsg(conn, &msg.NatHoleClientDetectOK{})
Expand Down
7 changes: 4 additions & 3 deletions client/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"io"
"net"
"strconv"
"sync"
"time"

Expand Down Expand Up @@ -85,7 +86,7 @@ type STCPVisitor struct {
}

func (sv *STCPVisitor) Run() (err error) {
sv.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", sv.cfg.BindAddr, sv.cfg.BindPort))
sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
if err != nil {
return
}
Expand Down Expand Up @@ -174,7 +175,7 @@ type XTCPVisitor struct {
}

func (sv *XTCPVisitor) Run() (err error) {
sv.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", sv.cfg.BindAddr, sv.cfg.BindPort))
sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
if err != nil {
return
}
Expand Down Expand Up @@ -352,7 +353,7 @@ type SUDPVisitor struct {
func (sv *SUDPVisitor) Run() (err error) {
xl := xlog.FromContextSafe(sv.ctx)

addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", sv.cfg.BindAddr, sv.cfg.BindPort))
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
if err != nil {
return fmt.Errorf("sudp ResolveUDPAddr error: %v", err)
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/util/net/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"fmt"
"io"
"net"
"strconv"
"sync"
"time"

Expand Down Expand Up @@ -163,7 +164,7 @@ type UDPListener struct {
}

func ListenUDP(bindAddr string, bindPort int) (l *UDPListener, err error) {
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
udpAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(bindAddr, strconv.Itoa(bindPort)))
if err != nil {
return l, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/util/net/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package net

import (
"errors"
"fmt"
"net"
"net/http"
"strconv"

"golang.org/x/net/websocket"
)
Expand Down Expand Up @@ -52,7 +52,7 @@ func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
}

func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) {
tcpLn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
tcpLn, err := net.Listen("tcp", net.JoinHostPort(bindAddr, strconv.Itoa(bindPort)))
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/util/tcpmux/httpconnect.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func readHTTPConnectRequest(rd io.Reader) (host string, err error) {
return
}

host = util.GetHostFromAddr(req.Host)
host, _ = util.CanonicalHost(req.Host)
return
}

Expand Down
11 changes: 0 additions & 11 deletions pkg/util/util/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,6 @@ func OkResponse() *http.Response {
return res
}

// TODO: use "CanonicalHost" func to replace all "GetHostFromAddr" func.
func GetHostFromAddr(addr string) (host string) {
strs := strings.Split(addr, ":")
if len(strs) > 1 {
host = strs[0]
} else {
host = addr
}
return
}

// canonicalHost strips port from host if present and returns the canonicalized
// host name.
func CanonicalHost(host string) (string, error) {
Expand Down
3 changes: 2 additions & 1 deletion pkg/util/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"strconv"
"strings"
)
Expand Down Expand Up @@ -52,7 +53,7 @@ func CanonicalAddr(host string, port int) (addr string) {
if port == 80 || port == 443 {
addr = host
} else {
addr = fmt.Sprintf("%s:%d", host, port)
addr = net.JoinHostPort(host, strconv.Itoa(port))
}
return
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/util/vhost/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
Director: func(req *http.Request) {
req.URL.Scheme = "http"
url := req.Context().Value(RouteInfoURL).(string)
oldHost := util.GetHostFromAddr(req.Context().Value(RouteInfoHost).(string))
oldHost, _ := util.CanonicalHost(req.Context().Value(RouteInfoHost).(string))
rc := rp.GetRouteConfig(oldHost, url)
if rc != nil {
if rc.RewriteHost != "" {
Expand All @@ -81,7 +81,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
IdleConnTimeout: 60 * time.Second,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
url := ctx.Value(RouteInfoURL).(string)
host := util.GetHostFromAddr(ctx.Value(RouteInfoHost).(string))
host, _ := util.CanonicalHost(ctx.Value(RouteInfoHost).(string))
remote := ctx.Value(RouteInfoRemote).(string)
return rp.CreateConnection(host, url, remote)
},
Expand Down Expand Up @@ -191,7 +191,7 @@ func (rp *HTTPReverseProxy) getVhost(domain string, location string) (vr *Router
}

func (rp *HTTPReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
domain := util.GetHostFromAddr(req.Host)
domain, _ := util.CanonicalHost(req.Host)
location := req.URL.Path
user, passwd, _ := req.BasicAuth()
if !rp.CheckAuth(domain, location, user, passwd) {
Expand Down
4 changes: 2 additions & 2 deletions server/group/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
package group

import (
"fmt"
"net"
"strconv"
"sync"

"github.com/fatedier/frp/server/ports"
Expand Down Expand Up @@ -101,7 +101,7 @@ func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr
if err != nil {
return
}
tcpLn, errRet := net.Listen("tcp", fmt.Sprintf("%s:%d", addr, port))
tcpLn, errRet := net.Listen("tcp", net.JoinHostPort(addr, strconv.Itoa(port)))
if errRet != nil {
err = errRet
return
Expand Down
6 changes: 3 additions & 3 deletions server/ports/ports.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package ports

import (
"errors"
"fmt"
"net"
"strconv"
"sync"
"time"
)
Expand Down Expand Up @@ -134,7 +134,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {

func (pm *Manager) isPortAvailable(port int) bool {
if pm.netType == "udp" {
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pm.bindAddr, port))
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pm.bindAddr, strconv.Itoa(port)))
if err != nil {
return false
}
Expand All @@ -146,7 +146,7 @@ func (pm *Manager) isPortAvailable(port int) bool {
return true
}

l, err := net.Listen(pm.netType, fmt.Sprintf("%s:%d", pm.bindAddr, port))
l, err := net.Listen(pm.netType, net.JoinHostPort(pm.bindAddr, strconv.Itoa(port)))
if err != nil {
return false
}
Expand Down
3 changes: 2 additions & 1 deletion server/proxy/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package proxy
import (
"fmt"
"net"
"strconv"

"github.com/fatedier/frp/pkg/config"
)
Expand Down Expand Up @@ -54,7 +55,7 @@ func (pxy *TCPProxy) Run() (remoteAddr string, err error) {
pxy.rc.TCPPortManager.Release(pxy.realPort)
}
}()
listener, errRet := net.Listen("tcp", fmt.Sprintf("%s:%d", pxy.serverCfg.ProxyBindAddr, pxy.realPort))
listener, errRet := net.Listen("tcp", net.JoinHostPort(pxy.serverCfg.ProxyBindAddr, strconv.Itoa(pxy.realPort)))
if errRet != nil {
err = errRet
return
Expand Down
3 changes: 2 additions & 1 deletion server/proxy/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"io"
"net"
"strconv"
"time"

"github.com/fatedier/frp/pkg/config"
Expand Down Expand Up @@ -70,7 +71,7 @@ func (pxy *UDPProxy) Run() (remoteAddr string, err error) {

remoteAddr = fmt.Sprintf(":%d", pxy.realPort)
pxy.cfg.RemotePort = pxy.realPort
addr, errRet := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pxy.serverCfg.ProxyBindAddr, pxy.realPort))
addr, errRet := net.ResolveUDPAddr("udp", net.JoinHostPort(pxy.serverCfg.ProxyBindAddr, strconv.Itoa(pxy.realPort)))
if errRet != nil {
err = errRet
return
Expand Down
11 changes: 6 additions & 5 deletions server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
// Create tcpmux httpconnect multiplexer.
if cfg.TCPMuxHTTPConnectPort > 0 {
var l net.Listener
l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.ProxyBindAddr, cfg.TCPMuxHTTPConnectPort))
address := net.JoinHostPort(cfg.ProxyBindAddr, strconv.Itoa(cfg.TCPMuxHTTPConnectPort))
l, err = net.Listen("tcp", address)
if err != nil {
err = fmt.Errorf("Create server listener error, %v", err)
return
Expand All @@ -135,7 +136,7 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
err = fmt.Errorf("Create vhost tcpMuxer error, %v", err)
return
}
log.Info("tcpmux httpconnect multiplexer listen on %s:%d", cfg.ProxyBindAddr, cfg.TCPMuxHTTPConnectPort)
log.Info("tcpmux httpconnect multiplexer listen on %s", address)
}

// Init all plugins
Expand Down Expand Up @@ -199,7 +200,7 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
err = fmt.Errorf("Listen on kcp address udp %s error: %v", address, err)
return
}
log.Info("frps kcp listen on udp %s:%d", cfg.BindAddr, cfg.KCPBindPort)
log.Info("frps kcp listen on udp %s", address)
}

// Listen for accepting connections from client using websocket protocol.
Expand Down Expand Up @@ -232,7 +233,7 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
}
}
go server.Serve(l)
log.Info("http service listen on %s:%d", cfg.ProxyBindAddr, cfg.VhostHTTPPort)
log.Info("http service listen on %s", address)
}

// Create https vhost muxer.
Expand Down Expand Up @@ -288,7 +289,7 @@ func NewService(cfg config.ServerCommonConf) (svr *Service, err error) {
err = fmt.Errorf("Create dashboard web server error, %v", err)
return
}
log.Info("Dashboard listen on %s:%d", cfg.DashboardAddr, cfg.DashboardPort)
log.Info("Dashboard listen on %s", address)
statsEnable = true
}
if statsEnable {
Expand Down
20 changes: 20 additions & 0 deletions test/e2e/basic/client_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,4 +249,24 @@ var _ = Describe("[Feature: Client-Server]", func() {
})
}
})

Describe("IPv6 bind address", func() {
supportProtocols := []string{"tcp", "kcp", "websocket"}
for _, protocol := range supportProtocols {
tmp := protocol
defineClientServerTest("IPv6 bind address: "+strings.ToUpper(tmp), f, &generalTestConfigures{
server: fmt.Sprintf(`
bind_addr = ::
kcp_bind_port = {{ .%s }}
protocol = %s
`, consts.PortServerName, protocol),
client: fmt.Sprintf(`
tls_enable = true
protocol = %s
disable_custom_tls_first_byte = true
`, protocol),
})
}
})

})
3 changes: 1 addition & 2 deletions test/e2e/mock/server/httpserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package httpserver

import (
"crypto/tls"
"fmt"
"net"
"net/http"
"strconv"
Expand Down Expand Up @@ -97,7 +96,7 @@ func (s *Server) Close() error {
}

func (s *Server) initListener() (err error) {
s.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", s.bindAddr, s.bindPort))
s.l, err = net.Listen("tcp", net.JoinHostPort(s.bindAddr, strconv.Itoa(s.bindPort)))
return
}

Expand Down
3 changes: 2 additions & 1 deletion test/e2e/mock/server/streamserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net"
"strconv"

libnet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/test/e2e/pkg/rpc"
Expand Down Expand Up @@ -99,7 +100,7 @@ func (s *Server) Close() error {
func (s *Server) initListener() (err error) {
switch s.netType {
case TCP:
s.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", s.bindAddr, s.bindPort))
s.l, err = net.Listen("tcp", net.JoinHostPort(s.bindAddr, strconv.Itoa(s.bindPort)))
case UDP:
s.l, err = libnet.ListenUDP(s.bindAddr, s.bindPort)
case Unix:
Expand Down
5 changes: 3 additions & 2 deletions test/e2e/pkg/port/port.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package port
import (
"fmt"
"net"
"strconv"
"sync"

"k8s.io/apimachinery/pkg/util/sets"
Expand Down Expand Up @@ -57,15 +58,15 @@ func (pa *Allocator) GetByName(portName string) int {
return 0
}

l, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
l, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)))
if err != nil {
// Maybe not controlled by us, mark it used.
pa.used.Insert(port)
continue
}
l.Close()

udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", port))
udpAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)))
if err != nil {
continue
}
Expand Down

0 comments on commit 6194273

Please sign in to comment.