Skip to content

Commit

Permalink
修复地址解析问题
Browse files Browse the repository at this point in the history
  • Loading branch information
davyxu committed Nov 6, 2018
1 parent a11c867 commit db4866c
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 43 deletions.
2 changes: 1 addition & 1 deletion examples/websocket/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func main() {
// 创建一个事件处理队列,整个服务器只有这一个队列处理事件,服务器属于单线程服务器
queue := cellnet.NewEventQueue()

p := peer.NewGenericPeer("gorillaws.Acceptor", "server", "http://127.0.0.1:18802/echo", queue)
p := peer.NewGenericPeer("gorillaws.Acceptor", "server", "http://127.0.0.1:18802~18803/echo", queue)

proc.BindProcessorHandler(p, "gorillaws.ltv", func(ev cellnet.Event) {

Expand Down
27 changes: 11 additions & 16 deletions peer/gorillaws/acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@ import (
"github.com/davyxu/cellnet/peer"
"github.com/davyxu/cellnet/util"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"net"
"net/http"
"net/url"
)

type wsAcceptor struct {
Expand Down Expand Up @@ -53,21 +51,14 @@ func (self *wsAcceptor) SetHttps(certfile, keyfile string) {

func (self *wsAcceptor) Start() cellnet.Peer {

var addrURL *url.URL
var addrObj *util.Address
var err error
var raw interface{}
raw, err = util.DetectPort(self.Address(), func(s string) (interface{}, error) {
raw, err = util.DetectPort(self.Address(), func(a *util.Address) (interface{}, error) {

addrURL, err = url.Parse(s)
addrObj = a

if err != nil {
return nil, err
}

if addrURL.Path == "" {
return nil, errors.New("expect path in url to listen")
}
return net.Listen("tcp", addrURL.Host)
return net.Listen("tcp", a.HostPort())
})

if err != nil {
Expand All @@ -79,7 +70,11 @@ func (self *wsAcceptor) Start() cellnet.Peer {

mux := http.NewServeMux()

mux.HandleFunc(addrURL.Path, func(w http.ResponseWriter, r *http.Request) {
if addrObj.Path == "" {
addrObj.Path = "/"
}

mux.HandleFunc(addrObj.Path, func(w http.ResponseWriter, r *http.Request) {

c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
Expand All @@ -95,11 +90,11 @@ func (self *wsAcceptor) Start() cellnet.Peer {

})

self.sv = &http.Server{Addr: addrURL.Host, Handler: mux}
self.sv = &http.Server{Addr: addrObj.HostPort(), Handler: mux}

go func() {

log.Infof("#websocket.listen(%s) %s", self.Name(), addrURL.String())
log.Infof("#websocket.listen(%s) %s", self.Name(), addrObj.String())

if self.certfile != "" && self.keyfile != "" {
err = self.sv.ServeTLS(self.listener, self.certfile, self.keyfile)
Expand Down
4 changes: 2 additions & 2 deletions peer/http/acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ func (self *httpAcceptor) Start() cellnet.Peer {

self.sv = &http.Server{Addr: self.Address(), Handler: self}

ln, err := util.DetectPort(self.Address(), func(s string) (interface{}, error) {
return net.Listen("tcp", s)
ln, err := util.DetectPort(self.Address(), func(a *util.Address) (interface{}, error) {
return net.Listen("tcp", a.HostPort())
})

self.listener = ln.(net.Listener)
Expand Down
4 changes: 2 additions & 2 deletions peer/tcp/acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ func (self *tcpAcceptor) Start() cellnet.Peer {
return self
}

ln, err := util.DetectPort(self.Address(), func(s string) (interface{}, error) {
return net.Listen("tcp", s)
ln, err := util.DetectPort(self.Address(), func(a *util.Address) (interface{}, error) {
return net.Listen("tcp", a.HostPort())
})

if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions peer/udp/acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ func (self *udpAcceptor) Start() cellnet.Peer {
self.mtTotalRecvUDPPacket = expvar.NewInt(fmt.Sprintf("cellnet.Peer(%s).TotalRecvUDPPacket", self.Name()))
}

ln, err := util.DetectPort(self.Address(), func(s string) (interface{}, error) {
ln, err := util.DetectPort(self.Address(), func(a *util.Address) (interface{}, error) {

addr, err := net.ResolveUDPAddr("udp", s)
addr, err := net.ResolveUDPAddr("udp", a.HostPort())
if err != nil {
return nil, err
}
Expand Down
93 changes: 73 additions & 20 deletions util/addr.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,42 +47,95 @@ var (
ErrInvalidPortRange = errors.New("invalid port range")
)

// 在给定的端口范围内找到一个能用的端口 格式: localhost:5000~6000
func DetectPort(addr string, fn func(string) (interface{}, error)) (interface{}, error) {
// host:port 或 host:min~max
parts := strings.Split(addr, ":")
type Address struct {
Scheme string
Host string
Port int
Path string
}

// host:port格式
if len(parts) < 2 {
return fn(addr)
func (self *Address) String() string {
if self.Scheme == "" {
return fmt.Sprintf("%s:%d", self.Host, self.Port)
}

// 间隔分割
ports := strings.Split(parts[len(parts)-1], "~")
return fmt.Sprintf("%s://%s:%d%s", self.Scheme, self.Host, self.Port, self.Path)
}

func (self *Address) HostPort() string {

return fmt.Sprintf("%s:%d", self.Host, self.Port)
}

// 在给定的端口范围内找到一个能用的端口 格式:
// scheme://host:minPort~maxPort/path
func DetectPort(addr string, fn func(*Address) (interface{}, error)) (interface{}, error) {

// 单独的端口
if len(ports) < 2 {
return fn(addr)
var addrObj Address
schemePos := strings.Index(addr, "://")

// 移除scheme部分
if schemePos != -1 {
addrObj.Scheme = addr[:schemePos]
addr = addr[schemePos+3:]
}

// extract min port
min, err := strconv.Atoi(ports[0])
if err != nil {
return nil, ErrInvalidPortRange
colonPos := strings.Index(addr, ":")

if colonPos != -1 {
addrObj.Host = addr[:colonPos]
}

addr = addr[colonPos+1:]

rangePos := strings.Index(addr, "~")

var minStr, maxStr string
if rangePos != -1 {
minStr = addr[:rangePos]

slashPos := strings.Index(addr, "/")

if slashPos != -1 {
maxStr = addr[rangePos+1 : slashPos]
addrObj.Path = addr[slashPos:]
} else {
maxStr = addr[rangePos:]
}
} else {
slashPos := strings.Index(addr, "/")

if slashPos != -1 {
addrObj.Path = addr[slashPos:]
minStr = addr[rangePos+1 : slashPos]
} else {
minStr = addr[rangePos+1:]
}
}

// extract max port
max, err := strconv.Atoi(ports[1])
// extract min port
min, err := strconv.Atoi(minStr)
if err != nil {
return nil, ErrInvalidPortRange
}

host := parts[0]
var max int
if maxStr != "" {
// extract max port
max, err = strconv.Atoi(maxStr)
if err != nil {
return nil, ErrInvalidPortRange
}
} else {
max = min
}

for port := min; port <= max; port++ {

addrObj.Port = port

// 使用回调侦听
ln, err := fn(fmt.Sprintf("%s:%d", host, port))
ln, err := fn(&addrObj)
if err == nil {
return ln, nil
}
Expand Down
16 changes: 16 additions & 0 deletions util/addr_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package util

import (
"errors"
"testing"
)

func TestDetectPort(t *testing.T) {
DetectPort("scheme://host:100~200/path", func(s string) (interface{}, error) {
if s != "host:100" {
t.FailNow()
}

return nil, errors.New("err")
})
}

0 comments on commit db4866c

Please sign in to comment.