From 2f2331a50cd3d06c5be273dcbe5f8530c1b56214 Mon Sep 17 00:00:00 2001 From: Ehco Date: Fri, 16 Aug 2024 14:36:54 +0800 Subject: [PATCH] feat: support relay udp (#361) --- examples/block.json | 36 --- examples/cf-ws.json | 15 -- examples/config.json | 62 ++--- examples/mptcp.json | 23 -- examples/mws.json | 24 -- examples/speed_limit.json | 14 - examples/sub.json | 20 -- examples/web.json | 16 -- examples/with_ping.json | 29 --- examples/ws_config.json | 24 -- internal/cli/config.go | 14 - internal/cli/flags.go | 9 +- internal/cmgr/cmgr.go | 3 +- internal/config/config.go | 14 +- internal/conn/relay_conn.go | 340 ++++++++++++++----------- internal/conn/relay_conn_test.go | 21 +- internal/conn/udp_listener.go | 187 ++++++++++++++ internal/conn/ws_conn.go | 3 +- internal/constant/constant.go | 25 +- internal/relay/conf/cfg.go | 89 +++---- internal/relay/relay.go | 10 +- internal/relay/server.go | 6 +- internal/relay/server_reloader.go | 6 +- internal/transporter/base.go | 189 +++++++++----- internal/transporter/interface.go | 25 +- internal/transporter/mux.go | 201 --------------- internal/transporter/raw.go | 101 +++++--- internal/transporter/raw_mux.go | 109 -------- internal/transporter/ws.go | 110 ++++---- internal/transporter/ws_mux.go | 122 --------- internal/transporter/wss.go | 18 +- internal/transporter/wss_mux.go | 122 --------- internal/web/templates/connection.html | 4 +- pkg/buffer/buffer.go | 4 +- pkg/metric_reader/reader.go | 2 +- pkg/sub/clash.go | 4 +- pkg/sub/clash_types.go | 4 +- test/relay_test.go | 178 +++++-------- 38 files changed, 835 insertions(+), 1348 deletions(-) delete mode 100644 examples/block.json delete mode 100644 examples/cf-ws.json delete mode 100644 examples/mptcp.json delete mode 100644 examples/mws.json delete mode 100644 examples/speed_limit.json delete mode 100644 examples/sub.json delete mode 100644 examples/web.json delete mode 100644 examples/with_ping.json delete mode 100644 examples/ws_config.json create mode 100644 internal/conn/udp_listener.go delete mode 100644 internal/transporter/mux.go delete mode 100644 internal/transporter/raw_mux.go delete mode 100644 internal/transporter/ws_mux.go delete mode 100644 internal/transporter/wss_mux.go diff --git a/examples/block.json b/examples/block.json deleted file mode 100644 index 8de54e750..000000000 --- a/examples/block.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "log_level": "debug", - "relay_configs": [ - { - "label": "relay-to-http-success", - "listen": "127.0.0.1:1234", - "listen_type": "raw", - "transport_type": "raw", - "tcp_remotes": ["google.com:80"], - "blocked_protocols": ["tls"] - }, - { - "label": "relay-to-http-fail", - "listen": "127.0.0.1:1235", - "listen_type": "raw", - "transport_type": "raw", - "tcp_remotes": ["google.com:80"], - "blocked_protocols": ["http"] - }, - { - "label": "relay-to-tls-success", - "listen": "127.0.0.1:1236", - "listen_type": "raw", - "transport_type": "raw", - "tcp_remotes": ["google.com:443"] - }, - { - "label": "relay-to-tls-fail", - "listen": "127.0.0.1:1237", - "listen_type": "raw", - "transport_type": "raw", - "tcp_remotes": ["google.com:443"], - "blocked_protocols": ["tls"] - } - ] -} diff --git a/examples/cf-ws.json b/examples/cf-ws.json deleted file mode 100644 index 0294138d5..000000000 --- a/examples/cf-ws.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "web_port": 9000, - "relay_configs": [ - { - "listen": "127.0.0.1:1235", - "listen_type": "raw", - "transport_type": "ws", - "tcp_remotes": ["ws://0.0.0.0:8787"], - "ws_config": { - "path": "pwd", - "remote_addr": "127.0.0.1:5201" - } - } - ] -} diff --git a/examples/config.json b/examples/config.json index e538aa9bd..6d1926529 100644 --- a/examples/config.json +++ b/examples/config.json @@ -9,71 +9,41 @@ "listen_type": "raw", "transport_type": "raw", "label": "relay1", - "tcp_remotes": ["0.0.0.0:5201"], - "udp_remotes": ["0.0.0.0:5201"] + "tcp_remotes": [ + "0.0.0.0:5201" + ] }, { "listen": "127.0.0.1:1235", "listen_type": "raw", "transport_type": "ws", - "tcp_remotes": ["ws://0.0.0.0:2443"], - "udp_remotes": ["0.0.0.0:5201"] + "tcp_remotes": [ + "ws://0.0.0.0:2443" + ] }, { "listen": "127.0.0.1:1236", "listen_type": "raw", "transport_type": "wss", - "tcp_remotes": ["wss://0.0.0.0:3443"], - "udp_remotes": ["0.0.0.0:5201"] - }, - { - "listen": "127.0.0.1:1237", - "listen_type": "raw", - "transport_type": "mwss", - "tcp_remotes": ["wss://0.0.0.0:4443"], - "udp_remotes": ["0.0.0.0:5201"] - }, - { - "listen": "127.0.0.1:1238", - "listen_type": "raw", - "transport_type": "mtcp", - "tcp_remotes": ["0.0.0.0:4444"], - "udp_remotes": ["0.0.0.0:5201"] + "tcp_remotes": [ + "wss://0.0.0.0:3443" + ] }, { "listen": "127.0.0.1:2443", "listen_type": "ws", "transport_type": "raw", - "tcp_remotes": ["0.0.0.0:5201"], - "udp_remotes": [] + "tcp_remotes": [ + "0.0.0.0:5201" + ] }, { "listen": "127.0.0.1:3443", "listen_type": "wss", "transport_type": "raw", - "tcp_remotes": ["0.0.0.0:5201"], - "udp_remotes": [] - }, - { - "listen": "127.0.0.1:4443", - "listen_type": "mwss", - "transport_type": "raw", - "tcp_remotes": ["0.0.0.0:5201"], - "udp_remotes": [] - }, - { - "listen": "127.0.0.1:4444", - "listen_type": "mtcp", - "transport_type": "raw", - "tcp_remotes": ["0.0.0.0:5201"], - "udp_remotes": [] - }, - { - "label": "ping_test", - "listen": "127.0.0.1:8888", - "listen_type": "raw", - "transport_type": "raw", - "tcp_remotes": ["8.8.8.8:5201", "google.com:5201"] + "tcp_remotes": [ + "0.0.0.0:5201" + ] } ] -} +} \ No newline at end of file diff --git a/examples/mptcp.json b/examples/mptcp.json deleted file mode 100644 index 1f74f0f5f..000000000 --- a/examples/mptcp.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "log_level": "info", - "relay_configs": [ - { - "label": "client", - "listen": "127.0.0.1:1234", - "listen_type": "raw", - "transport_type": "raw", - "tcp_remotes": [ - "0.0.0.0:1235" - ] - }, - { - "label": "server", - "listen": "127.0.0.1:1235", - "listen_type": "raw", - "transport_type": "raw", - "tcp_remotes": [ - "0.0.0.0:5201" - ] - } - ] -} \ No newline at end of file diff --git a/examples/mws.json b/examples/mws.json deleted file mode 100644 index 62aff74a0..000000000 --- a/examples/mws.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "relay_configs": [ - { - "listen": "127.0.0.1:1235", - "listen_type": "raw", - "transport_type": "mws", - "tcp_remotes": ["ws://0.0.0.0:2443"], - "ws_config": { - "path": "pwd", - "remote_addr": "127.0.0.1:5201" - } - }, - { - "listen": "127.0.0.1:2443", - "listen_type": "mws", - "transport_type": "raw", - "tcp_remotes": ["0.0.0.0:5201"], - "ws_config": { - "path": "pwd", - "remote_addr": "127.0.0.1:5201" - } - } - ] -} diff --git a/examples/speed_limit.json b/examples/speed_limit.json deleted file mode 100644 index 7dd587d54..000000000 --- a/examples/speed_limit.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "relay_configs": [ - { - "listen": "127.0.0.1:1234", - "listen_type": "raw", - "transport_type": "raw", - "label": "iperf3", - "tcp_remotes": [ - "0.0.0.0:5201" - ], - "max_read_rate_kbps": 10000 - } - ] -} \ No newline at end of file diff --git a/examples/sub.json b/examples/sub.json deleted file mode 100644 index 940c77932..000000000 --- a/examples/sub.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "web_port": 9000, - "log_level": "debug", - "reload_interval": 10, - "relay_configs": [ - { - "listen": "127.0.0.1:1234", - "listen_type": "raw", - "transport_type": "raw", - "tcp_remotes": ["0.0.0.0:5201"], - "udp_remotes": ["0.0.0.0:5201"] - } - ], - "sub_configs": [ - { - "name": "sub1", - "url": "xxx" - } - ] -} diff --git a/examples/web.json b/examples/web.json deleted file mode 100644 index 2a52dfb47..000000000 --- a/examples/web.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "web_port": 9000, - "web_token": "", - "web_auth_user": "user", - "web_auth_pass": "pass", - "log_level": "info", - "relay_configs": [ - { - "listen": "127.0.0.1:1234", - "listen_type": "raw", - "transport_type": "raw", - "label": "iperf3", - "tcp_remotes": ["0.0.0.0:5201"] - } - ] -} diff --git a/examples/with_ping.json b/examples/with_ping.json deleted file mode 100644 index 78c26e8b3..000000000 --- a/examples/with_ping.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "web_port": 9000, - "log_level": "info", - "enable_ping": true, - "relay_sync_interval": 6, - "relay_configs": [ - { - "listen": "127.0.0.1:1234", - "listen_type": "raw", - "transport_type": "raw", - "label": "raw", - "tcp_remotes": ["192.168.31.30:5201"] - }, - { - "listen": "127.0.0.1:1235", - "listen_type": "raw", - "transport_type": "ws", - "label": "ws", - "tcp_remotes": ["ws://192.168.31.30:2443"] - }, - { - "listen": "127.0.0.1:1236", - "listen_type": "raw", - "transport_type": "wss", - "label": "wss", - "tcp_remotes": ["wss://192.168.31.31:2443"] - } - ] -} diff --git a/examples/ws_config.json b/examples/ws_config.json deleted file mode 100644 index 58172c0dc..000000000 --- a/examples/ws_config.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "relay_configs": [ - { - "listen": "127.0.0.1:1235", - "listen_type": "raw", - "transport_type": "ws", - "tcp_remotes": ["ws://0.0.0.0:2443"], - "ws_config": { - "path": "pwd", - "remote_addr": "127.0.0.1:5201" - } - }, - { - "listen": "127.0.0.1:2443", - "listen_type": "ws", - "transport_type": "raw", - "tcp_remotes": ["0.0.0.0:5201"], - "ws_config": { - "path": "pwd", - "remote_addr": "127.0.0.1:5201" - } - } - ] -} diff --git a/internal/cli/config.go b/internal/cli/config.go index f23fba816..f5b025bea 100644 --- a/internal/cli/config.go +++ b/internal/cli/config.go @@ -9,7 +9,6 @@ import ( "github.com/Ehco1996/ehco/internal/metrics" "github.com/Ehco1996/ehco/internal/relay" "github.com/Ehco1996/ehco/internal/relay/conf" - "github.com/Ehco1996/ehco/internal/tls" "github.com/Ehco1996/ehco/internal/web" "github.com/Ehco1996/ehco/pkg/buffer" "github.com/Ehco1996/ehco/pkg/log" @@ -42,24 +41,11 @@ func loadConfig() (cfg *config.Config, err error) { if TCPRemoteAddr != "" { cfg.RelayConfigs[0].TCPRemotes = []string{TCPRemoteAddr} } - if UDPRemoteAddr != "" { - cfg.RelayConfigs[0].UDPRemotes = []string{UDPRemoteAddr} - } if err := cfg.Adjust(); err != nil { return nil, err } } - // init tls when need - for _, cfg := range cfg.RelayConfigs { - if cfg.ListenType == constant.RelayTypeWSS || cfg.ListenType == constant.RelayTypeMWSS || - cfg.TransportType == constant.RelayTypeWSS || cfg.TransportType == constant.RelayTypeMWSS { - if err := tls.InitTlsCfg(); err != nil { - return nil, err - } - break - } - } return cfg, nil } diff --git a/internal/cli/flags.go b/internal/cli/flags.go index 05ce342b2..b78a2b182 100644 --- a/internal/cli/flags.go +++ b/internal/cli/flags.go @@ -10,7 +10,6 @@ var ( LocalAddr string ListenType constant.RelayType TCPRemoteAddr string - UDPRemoteAddr string TransportType constant.RelayType ConfigPath string WebPort int @@ -39,16 +38,10 @@ var RootFlags = []cli.Flag{ }, &cli.StringFlag{ Name: "r,remote", - Usage: "TCP 转发地址,例如 0.0.0.0:5201,通过 ws 隧道转发时应为 ws://0.0.0.0:2443", + Usage: "转发地址,例如 0.0.0.0:5201,通过 ws 隧道转发时应为 ws://0.0.0.0:2443", EnvVars: []string{"EHCO_REMOTE_ADDR"}, Destination: &TCPRemoteAddr, }, - &cli.StringFlag{ - Name: "ur,udp_remote", - Usage: "UDP 转发地址,例如 0.0.0.0:1234", - EnvVars: []string{"EHCO_UDP_REMOTE_ADDR"}, - Destination: &UDPRemoteAddr, - }, &cli.StringFlag{ Name: "tt,transport_type", Value: "raw", diff --git a/internal/cmgr/cmgr.go b/internal/cmgr/cmgr.go index 560ac509a..18b05f2fa 100644 --- a/internal/cmgr/cmgr.go +++ b/internal/cmgr/cmgr.go @@ -16,7 +16,8 @@ const ( ConnectionTypeClosed = "closed" ) -// connection manager interface +// connection manager interface/ +// TODO support closed connection type Cmgr interface { ListConnections(connType string, page, pageSize int) []conn.RelayConn diff --git a/internal/config/config.go b/internal/config/config.go index 2cb14ac97..0de55fa2c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -7,9 +7,10 @@ import ( "strings" "time" - myhttp "github.com/Ehco1996/ehco/pkg/http" - + "github.com/Ehco1996/ehco/internal/constant" "github.com/Ehco1996/ehco/internal/relay/conf" + "github.com/Ehco1996/ehco/internal/tls" + myhttp "github.com/Ehco1996/ehco/pkg/http" "github.com/Ehco1996/ehco/pkg/sub" xConf "github.com/xtls/xray-core/infra/conf" "go.uber.org/zap" @@ -121,6 +122,15 @@ func (c *Config) Adjust() error { } labelMap[r.Label] = struct{}{} } + // init tls when need + for _, r := range c.RelayConfigs { + if r.ListenType == constant.RelayTypeWSS { + if err := tls.InitTlsCfg(); err != nil { + return err + } + break + } + } return nil } diff --git a/internal/conn/relay_conn.go b/internal/conn/relay_conn.go index f8b996218..74d1bd359 100644 --- a/internal/conn/relay_conn.go +++ b/internal/conn/relay_conn.go @@ -3,18 +3,160 @@ package conn import ( "crypto/sha256" "encoding/hex" + "errors" "fmt" "io" "net" "time" + "github.com/Ehco1996/ehco/internal/constant" + "github.com/Ehco1996/ehco/internal/lb" "github.com/Ehco1996/ehco/internal/metrics" "github.com/Ehco1996/ehco/pkg/buffer" "github.com/Ehco1996/ehco/pkg/bytes" "go.uber.org/zap" ) -var idleTimeout = 30 * time.Second +// RelayConn is the interface that represents a relay connection. +// it contains two connections: clientConn and remoteConn +// clientConn is the connection from the client to the relay server +// remoteConn is the connection from the relay server to the remote server +// and the main function is to transport data between the two connections +type RelayConn interface { + // Transport transports data between the client and the remote connection. + Transport() error + GetRelayLabel() string + GetStats() *Stats + Close() error +} + +type RelayConnOption func(*relayConnImpl) + +func NewRelayConn(clientConn, remoteConn net.Conn, opts ...RelayConnOption) RelayConn { + rci := &relayConnImpl{ + clientConn: clientConn, + remoteConn: remoteConn, + Stats: &Stats{}, + } + for _, opt := range opts { + opt(rci) + } + if rci.l == nil { + rci.l = zap.S().Named(rci.RelayLabel) + } + return rci +} + +type relayConnImpl struct { + clientConn net.Conn + remoteConn net.Conn + + Closed bool `json:"closed"` + + Stats *Stats `json:"stats"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time,omitempty"` + + // options set those fields + l *zap.SugaredLogger + remote *lb.Node + HandshakeDuration time.Duration + RelayLabel string `json:"relay_label"` + ConnType string `json:"conn_type"` +} + +func WithRelayLabel(relayLabel string) RelayConnOption { + return func(rci *relayConnImpl) { + rci.RelayLabel = relayLabel + } +} + +func WithHandshakeDuration(duration time.Duration) RelayConnOption { + return func(rci *relayConnImpl) { + rci.HandshakeDuration = duration + } +} + +func WithConnType(connType string) RelayConnOption { + return func(rci *relayConnImpl) { + rci.ConnType = connType + } +} + +func WithRemote(remote *lb.Node) RelayConnOption { + return func(rci *relayConnImpl) { + rci.remote = remote + } +} + +func WithLogger(l *zap.SugaredLogger) RelayConnOption { + return func(rci *relayConnImpl) { + rci.l = l + } +} + +func (rc *relayConnImpl) Transport() error { + defer rc.Close() // nolint: errcheck + cl := rc.l.Named(shortHashSHA256(rc.GetFlow())) + cl.Debugf("transport start, stats: %s", rc.Stats.String()) + c1 := newInnerConn(rc.clientConn, rc) + c2 := newInnerConn(rc.remoteConn, rc) + rc.StartTime = time.Now().Local() + err := copyConn(c1, c2) + if err != nil { + cl.Errorf("transport error: %s", err.Error()) + } + cl.Debugf("transport end, stats: %s", rc.Stats.String()) + rc.EndTime = time.Now().Local() + return err +} + +func (rc *relayConnImpl) Close() error { + err1 := rc.clientConn.Close() + err2 := rc.remoteConn.Close() + rc.Closed = true + return combineErrorsAndMuteEOF(err1, err2) +} + +// functions that for web ui +func (rc *relayConnImpl) GetTime() string { + if rc.EndTime.IsZero() { + return fmt.Sprintf("%s - N/A", rc.StartTime.Format(time.Stamp)) + } + return fmt.Sprintf("%s - %s", rc.StartTime.Format(time.Stamp), rc.EndTime.Format(time.Stamp)) +} + +func (rc *relayConnImpl) GetFlow() string { + return fmt.Sprintf("%s <-> %s", rc.clientConn.RemoteAddr(), rc.remoteConn.RemoteAddr()) +} + +func (rc *relayConnImpl) GetRelayLabel() string { + return rc.RelayLabel +} + +func (rc *relayConnImpl) GetStats() *Stats { + return rc.Stats +} + +func (rc *relayConnImpl) GetConnType() string { + return rc.ConnType +} + +func combineErrorsAndMuteEOF(err1, err2 error) error { + if err1 == io.EOF { + err1 = nil + } + if err2 == io.EOF { + return nil + } + if err1 != nil && err2 != nil { + return errors.Join(err1, err2) + } + if err1 != nil { + return err1 + } + return err2 +} type Stats struct { Up int64 @@ -35,52 +177,69 @@ func (s *Stats) String() string { ) } +// note that innerConn is a wrapper around net.Conn to allow io.Copy to be used type innerConn struct { net.Conn + lastActive time.Time - remoteLabel string - stats *Stats + rc *relayConnImpl } -func (c *innerConn) setDeadline(isRead bool) { - // set the read deadline to avoid hanging read for non-TCP connections - // because tcp connections have closeWrite/closeRead so no need to set read deadline - if _, ok := c.Conn.(*net.TCPConn); !ok { - deadline := time.Now().Add(idleTimeout) - if isRead { - _ = c.Conn.SetReadDeadline(deadline) - } else { - _ = c.Conn.SetWriteDeadline(deadline) - } - } +func newInnerConn(conn net.Conn, rc *relayConnImpl) *innerConn { + return &innerConn{Conn: conn, rc: rc, lastActive: time.Now()} } func (c *innerConn) recordStats(n int, isRead bool) { + if c.rc == nil { + return + } if isRead { metrics.NetWorkTransmitBytes.WithLabelValues( - c.remoteLabel, metrics.METRIC_CONN_TYPE_TCP, metrics.METRIC_CONN_FLOW_READ, + c.rc.remote.Label, metrics.METRIC_CONN_TYPE_TCP, metrics.METRIC_CONN_FLOW_READ, ).Add(float64(n)) - c.stats.Record(0, int64(n)) + c.rc.Stats.Record(0, int64(n)) } else { metrics.NetWorkTransmitBytes.WithLabelValues( - c.remoteLabel, metrics.METRIC_CONN_TYPE_TCP, metrics.METRIC_CONN_FLOW_WRITE, + c.rc.remote.Label, metrics.METRIC_CONN_TYPE_TCP, metrics.METRIC_CONN_FLOW_WRITE, ).Add(float64(n)) - c.stats.Record(int64(n), 0) + c.rc.Stats.Record(int64(n), 0) } } -// 修改Read和Write方法以使用recordStats func (c *innerConn) Read(p []byte) (n int, err error) { - c.setDeadline(true) - n, err = c.Conn.Read(p) - c.recordStats(n, true) // true for read operation - return + for { + deadline := time.Now().Add(constant.ReadTimeOut) + if err := c.Conn.SetReadDeadline(deadline); err != nil { + return 0, err + } + n, err = c.Conn.Read(p) + if err == nil { + c.recordStats(n, true) + c.lastActive = time.Now() + return + } else { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + if time.Since(c.lastActive) > constant.IdleTimeOut { + c.rc.l.Debugf("read idle,close remote: %s", c.rc.remote.Label) + return 0, io.EOF + } + continue + } + return n, err + } + } } func (c *innerConn) Write(p []byte) (n int, err error) { - c.setDeadline(false) + if time.Since(c.lastActive) > constant.IdleTimeOut { + c.rc.l.Debugf("write idle,close remote: %s", c.rc.remote.Label) + return 0, io.EOF + } n, err = c.Conn.Write(p) c.recordStats(n, false) // false for write operation + if err != nil { + c.lastActive = time.Now() + } return } @@ -109,10 +268,6 @@ func shortHashSHA256(input string) string { return hex.EncodeToString(hash)[:7] } -func connectionName(conn net.Conn) string { - return fmt.Sprintf("l:<%s> r:<%s>", conn.LocalAddr(), conn.RemoteAddr()) -} - func copyConn(conn1, conn2 *innerConn) error { buf := buffer.BufferPool.Get() defer buffer.BufferPool.Put(buf) @@ -125,134 +280,13 @@ func copyConn(conn1, conn2 *innerConn) error { }() // reverse copy conn2 to conn1,read from conn2 and write to conn1 - _, err := io.Copy(conn1, conn2) + buf2 := buffer.BufferPool.Get() + defer buffer.BufferPool.Put(buf2) + _, err := io.CopyBuffer(conn1, conn2, buf2) _ = conn1.CloseWrite() err2 := <-errCH _ = conn1.CloseRead() _ = conn2.CloseRead() - - // handle errors, need to combine errors from both directions - if err != nil && err2 != nil { - err = fmt.Errorf("transport errors in both directions: %v, %v", err, err2) - } - if err != nil { - return err - } - return err2 -} - -type RelayConnOption func(*relayConnImpl) - -type RelayConn interface { - // Transport transports data between the client and the remote server. - // The remoteLabel is the label of the remote server. - Transport(remoteLabel string) error - - // GetRelayLabel returns the label of the Relay instance. - GetRelayLabel() string - - GetStats() *Stats - - Close() error -} - -func NewRelayConn(relayName string, clientConn, remoteConn net.Conn, opts ...RelayConnOption) RelayConn { - rci := &relayConnImpl{ - RelayLabel: relayName, - clientConn: clientConn, - remoteConn: remoteConn, - } - for _, opt := range opts { - opt(rci) - } - s := &Stats{Up: 0, Down: 0, HandShakeLatency: rci.HandshakeDuration} - rci.Stats = s - return rci -} - -type relayConnImpl struct { - RelayLabel string `json:"relay_label"` - Closed bool `json:"closed"` - - StartTime time.Time `json:"start_time"` - EndTime time.Time `json:"end_time,omitempty"` - - Stats *Stats `json:"stats"` - HandshakeDuration time.Duration - - clientConn net.Conn - remoteConn net.Conn -} - -func WithHandshakeDuration(duration time.Duration) RelayConnOption { - return func(rci *relayConnImpl) { - rci.HandshakeDuration = duration - } -} - -func (rc *relayConnImpl) Transport(remoteLabel string) error { - defer rc.Close() // nolint: errcheck - name := rc.Name() - shortName := fmt.Sprintf("%s-%s", rc.RelayLabel, shortHashSHA256(name)) - cl := zap.L().Named(shortName) - cl.Debug("transport start", zap.String("full name", name), zap.String("stats", rc.Stats.String())) - c1 := &innerConn{ - stats: rc.Stats, - remoteLabel: remoteLabel, - Conn: rc.clientConn, - } - c2 := &innerConn{ - stats: rc.Stats, - remoteLabel: remoteLabel, - Conn: rc.remoteConn, - } - rc.StartTime = time.Now().Local() - err := copyConn(c1, c2) - if err != nil { - cl.Error("transport error", zap.Error(err)) - } - cl.Debug("transport end", zap.String("stats", rc.Stats.String())) - rc.EndTime = time.Now().Local() - return err -} - -func (rc *relayConnImpl) GetTime() string { - if rc.EndTime.IsZero() { - return fmt.Sprintf("%s - N/A", rc.StartTime.Format(time.Stamp)) - } - return fmt.Sprintf("%s - %s", rc.StartTime.Format(time.Stamp), rc.EndTime.Format(time.Stamp)) -} - -func (rc *relayConnImpl) Name() string { - return fmt.Sprintf("c1:[%s] c2:[%s]", connectionName(rc.clientConn), connectionName(rc.remoteConn)) -} - -func (rc *relayConnImpl) Flow() string { - return fmt.Sprintf("%s <-> %s", rc.clientConn.RemoteAddr(), rc.remoteConn.RemoteAddr()) -} - -func (rc *relayConnImpl) GetRelayLabel() string { - return rc.RelayLabel -} - -func (rc *relayConnImpl) GetStats() *Stats { - return rc.Stats -} - -func (rc *relayConnImpl) Close() error { - err1 := rc.clientConn.Close() - err2 := rc.remoteConn.Close() - rc.Closed = true - return combineErrors(err1, err2) -} - -func combineErrors(err1, err2 error) error { - if err1 != nil && err2 != nil { - return fmt.Errorf("combineErrors: %v, %v", err1, err2) - } - if err1 != nil { - return err1 - } - return err2 + return combineErrorsAndMuteEOF(err, err2) } diff --git a/internal/conn/relay_conn_test.go b/internal/conn/relay_conn_test.go index a83ccef7a..c76b417c0 100644 --- a/internal/conn/relay_conn_test.go +++ b/internal/conn/relay_conn_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/Ehco1996/ehco/internal/lb" "github.com/stretchr/testify/assert" ) @@ -18,11 +19,9 @@ func TestInnerConn_ReadWrite(t *testing.T) { serverConn.SetDeadline(time.Now().Add(1 * time.Second)) defer clientConn.Close() defer serverConn.Close() - - innerC := &innerConn{Conn: clientConn, stats: &Stats{}, remoteLabel: "test"} - + rc := relayConnImpl{Stats: &Stats{}, remote: &lb.Node{Label: "client"}} + innerC := newInnerConn(clientConn, &rc) errChan := make(chan error, 1) - go func() { _, err := innerC.Write(testData) errChan <- err @@ -39,7 +38,7 @@ func TestInnerConn_ReadWrite(t *testing.T) { if err := <-errChan; err != nil { t.Fatalf("write err: %v", err) } - assert.Equal(t, int64(len(testData)), innerC.stats.Up) + assert.Equal(t, int64(len(testData)), rc.Stats.Up) errChan = make(chan error, 1) clientConn.SetDeadline(time.Now().Add(1 * time.Second)) @@ -64,7 +63,7 @@ func TestInnerConn_ReadWrite(t *testing.T) { if err := <-errChan; err != nil { t.Fatalf("write error: %v", err) } - assert.Equal(t, int64(len(testData)), innerC.stats.Down) + assert.Equal(t, int64(len(testData)), rc.Stats.Down) } func TestCopyTCPConn(t *testing.T) { @@ -96,8 +95,9 @@ func TestCopyTCPConn(t *testing.T) { assert.NoError(t, err) defer remoteConn.Close() - c1 := &innerConn{Conn: clientConn, remoteLabel: "client", stats: &Stats{}} - c2 := &innerConn{Conn: remoteConn, remoteLabel: "server", stats: &Stats{}} + rc := relayConnImpl{Stats: &Stats{}, remote: &lb.Node{Label: "client"}} + c1 := newInnerConn(clientConn, &rc) + c2 := newInnerConn(remoteConn, &rc) done := make(chan struct{}) go func() { @@ -155,8 +155,9 @@ func TestCopyUDPConn(t *testing.T) { assert.NoError(t, err) defer remoteConn.Close() - c1 := &innerConn{Conn: clientConn, remoteLabel: "client", stats: &Stats{}} - c2 := &innerConn{Conn: remoteConn, remoteLabel: "server", stats: &Stats{}} + rc := relayConnImpl{Stats: &Stats{}, remote: &lb.Node{Label: "client"}} + c1 := newInnerConn(clientConn, &rc) + c2 := newInnerConn(remoteConn, &rc) done := make(chan struct{}) go func() { diff --git a/internal/conn/udp_listener.go b/internal/conn/udp_listener.go new file mode 100644 index 000000000..0385e0e34 --- /dev/null +++ b/internal/conn/udp_listener.go @@ -0,0 +1,187 @@ +//nolint:errcheck +package conn + +import ( + "context" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/Ehco1996/ehco/internal/constant" + "github.com/Ehco1996/ehco/pkg/buffer" +) + +var _ net.Conn = &uc{} + +type uc struct { + conn *net.UDPConn + addr *net.UDPAddr + + msgCh chan []byte + + lastActivity atomic.Value + + listener *UDPListener +} + +func (c *uc) Read(b []byte) (int, error) { + select { + case msg := <-c.msgCh: + n := copy(b, msg) + c.lastActivity.Store(time.Now()) + return n, nil + default: + if time.Since(c.lastActivity.Load().(time.Time)) > constant.IdleTimeOut { + return 0, io.EOF + } + return 0, nil + } +} + +func (c *uc) Write(b []byte) (int, error) { + n, err := c.conn.WriteToUDP(b, c.addr) + c.lastActivity.Store(time.Now()) + return n, err +} + +func (c *uc) Close() error { + c.listener.connsMu.Lock() + delete(c.listener.conns, c.addr.String()) + c.listener.connsMu.Unlock() + close(c.msgCh) + return nil +} + +func (c *uc) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *uc) RemoteAddr() net.Addr { + return c.addr +} + +func (c *uc) SetDeadline(t time.Time) error { + return nil +} + +func (c *uc) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *uc) SetWriteDeadline(t time.Time) error { + return nil +} + +type UDPListener struct { + listenAddr *net.UDPAddr + listenConn *net.UDPConn + + conns map[string]*uc + connsMu sync.RWMutex + connCh chan *uc + msgCh chan []byte + errCh chan error + + ctx context.Context + cancel context.CancelFunc + + closed atomic.Bool +} + +func NewUDPListener(ctx context.Context, addr string) (*UDPListener, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + + conn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithCancel(ctx) + + l := &UDPListener{ + listenConn: conn, + listenAddr: udpAddr, + + conns: make(map[string]*uc), + connCh: make(chan *uc), + msgCh: make(chan []byte), + errCh: make(chan error), + ctx: ctx, + cancel: cancel, + } + + go l.listen() + + return l, nil +} + +func (l *UDPListener) listen() { + defer l.listenConn.Close() + for { + if l.closed.Load() { + return + } + + buf := buffer.UDPBufferPool.Get() + n, addr, err := l.listenConn.ReadFromUDP(buf) + if err != nil { + if !l.closed.Load() { + select { + case l.errCh <- err: + default: + } + } + buffer.UDPBufferPool.Put(buf) + continue + } + + l.connsMu.RLock() + udpConn, exists := l.conns[addr.String()] + l.connsMu.RUnlock() + if !exists { + l.connsMu.Lock() + udpConn = &uc{ + conn: l.listenConn, + addr: addr, + listener: l, + msgCh: make(chan []byte, 10), + lastActivity: atomic.Value{}, + } + udpConn.lastActivity.Store(time.Now()) + l.conns[addr.String()] = udpConn + l.connCh <- udpConn + l.connsMu.Unlock() + } + + select { + case udpConn.msgCh <- buf[:n]: + default: + buffer.UDPBufferPool.Put(buf) + } + } +} + +func (l *UDPListener) Accept() (*uc, error) { + select { + case conn := <-l.connCh: + return conn, nil + case err := <-l.errCh: + return nil, err + case <-l.ctx.Done(): + return nil, l.ctx.Err() + } +} + +func (l *UDPListener) Close() error { + if !l.closed.CompareAndSwap(false, true) { + return nil + } + l.cancel() + l.closed.Store(true) + return l.listenConn.Close() +} diff --git a/internal/conn/ws_conn.go b/internal/conn/ws_conn.go index a555852c4..c2515918c 100644 --- a/internal/conn/ws_conn.go +++ b/internal/conn/ws_conn.go @@ -12,6 +12,7 @@ import ( "go.uber.org/zap" ) +// wsConn represents a WebSocket connection to relay(io.Copy) type wsConn struct { conn net.Conn isServer bool @@ -29,7 +30,7 @@ func (c *wsConn) Read(b []byte) (n int, err error) { } if header.Length > int64(cap(c.buf)) { zap.S().Warnf("ws payload size:%d is larger than buffer size:%d", header.Length, cap(c.buf)) - c.buf = make([]byte, header.Length) + return 0, fmt.Errorf("buffer size:%d too small to transport ws payload size:%d", len(b), header.Length) } payload := c.buf[:header.Length] _, err = io.ReadFull(c.conn, payload) diff --git a/internal/constant/constant.go b/internal/constant/constant.go index 1eebadb92..0b2050362 100644 --- a/internal/constant/constant.go +++ b/internal/constant/constant.go @@ -6,7 +6,9 @@ type RelayType string var ( // allow change in test - IdleTimeOut = 10 * time.Second + // TODO Set to Relay Config + ReadTimeOut = 5 * time.Second + IdleTimeOut = 30 * time.Second Version = "1.1.5-dev" GitBranch string @@ -20,24 +22,17 @@ const ( SniffTimeOut = 300 * time.Millisecond - SmuxGCDuration = 30 * time.Second - SmuxMaxAliveDuration = 10 * time.Minute - SmuxMaxStreamCnt = 5 - - // todo add udp buffer size + // todo,support config in relay config BUFFER_POOL_SIZE = 1024 // support 512 connections - BUFFER_SIZE = 20 * 1024 // 20KB the maximum packet size of shadowsocks is about 16 KiB + BUFFER_SIZE = 40 * 1024 // 40KB ,the maximum packet size of shadowsocks is about 16 KiB so this is enough + UDPBufSize = 1500 // use default max mtu 1500 ) // relay type const ( - // tcp relay - RelayTypeRaw RelayType = "raw" - RelayTypeMTCP RelayType = "mtcp" - + // direct relay + RelayTypeRaw RelayType = "raw" // ws relay - RelayTypeWS RelayType = "ws" - RelayTypeMWS RelayType = "mws" - RelayTypeWSS RelayType = "wss" - RelayTypeMWSS RelayType = "mwss" + RelayTypeWS RelayType = "ws" + RelayTypeWSS RelayType = "wss" ) diff --git a/internal/relay/conf/cfg.go b/internal/relay/conf/cfg.go index 00d28f5ce..831a9fe24 100644 --- a/internal/relay/conf/cfg.go +++ b/internal/relay/conf/cfg.go @@ -23,24 +23,47 @@ type WSConfig struct { RemoteAddr string `json:"remote_addr,omitempty"` } +func (w *WSConfig) Clone() *WSConfig { + return &WSConfig{ + Path: w.Path, + RemoteAddr: w.RemoteAddr, + } +} + +type Options struct { + WSConfig *WSConfig `json:"ws_config,omitempty"` + EnableUDP bool `json:"enable_udp,omitempty"` + EnableMultipathTCP bool `json:"enable_multipath_tcp,omitempty"` + + MaxConnection int `json:"max_connection,omitempty"` + BlockedProtocols []string `json:"blocked_protocols,omitempty"` + MaxReadRateKbps int64 `json:"max_read_rate_kbps,omitempty"` +} + +func (o *Options) Clone() *Options { + opt := &Options{ + EnableUDP: o.EnableUDP, + EnableMultipathTCP: o.EnableMultipathTCP, + } + if o.WSConfig != nil { + opt.WSConfig = o.WSConfig.Clone() + } + return opt +} + type Config struct { Label string `json:"label,omitempty"` Listen string `json:"listen"` ListenType constant.RelayType `json:"listen_type"` TransportType constant.RelayType `json:"transport_type"` - TCPRemotes []string `json:"tcp_remotes"` - UDPRemotes []string `json:"udp_remotes"` + TCPRemotes []string `json:"tcp_remotes"` // TODO rename to remotes - MaxConnection int `json:"max_connection,omitempty"` - BlockedProtocols []string `json:"blocked_protocols,omitempty"` - MaxReadRateKbps int64 `json:"max_read_rate_kbps,omitempty"` - - WSConfig *WSConfig `json:"ws_config,omitempty"` + Options *Options `json:"options,omitempty"` } func (r *Config) GetWSHandShakePath() string { - if r.WSConfig != nil && r.WSConfig.Path != "" { - return r.WSConfig.Path + if r.Options != nil && r.Options.WSConfig != nil && r.Options.WSConfig.Path != "" { + return r.Options.WSConfig.Path } return WS_HANDSHAKE_PATH } @@ -50,8 +73,8 @@ func (r *Config) GetWSRemoteAddr(baseAddr string) (string, error) { if err != nil { return "", err } - if r.WSConfig != nil && r.WSConfig.RemoteAddr != "" { - addr += fmt.Sprintf("?%s=%s", WS_QUERY_REMOTE_ADDR, r.WSConfig.RemoteAddr) + if r.Options != nil && r.Options.WSConfig != nil && r.Options.WSConfig.RemoteAddr != "" { + addr += fmt.Sprintf("?%s=%s", WS_QUERY_REMOTE_ADDR, r.Options.WSConfig.RemoteAddr) } return addr, nil } @@ -79,17 +102,7 @@ func (r *Config) Validate() error { } } - for _, addr := range r.UDPRemotes { - if addr == "" { - return fmt.Errorf("invalid udp remote addr:%s", addr) - } - } - - if len(r.TCPRemotes) == 0 && len(r.UDPRemotes) == 0 { - return errors.New("both tcp and udp remotes are empty") - } - - for _, protocol := range r.BlockedProtocols { + for _, protocol := range r.Options.BlockedProtocols { if protocol != ProtocolHTTP && protocol != ProtocolTLS { return fmt.Errorf("invalid blocked protocol:%s", protocol) } @@ -103,11 +116,10 @@ func (r *Config) Clone() *Config { ListenType: r.ListenType, TransportType: r.TransportType, Label: r.Label, + Options: r.Options.Clone(), } new.TCPRemotes = make([]string, len(r.TCPRemotes)) copy(new.TCPRemotes, r.TCPRemotes) - new.UDPRemotes = make([]string, len(r.UDPRemotes)) - copy(new.UDPRemotes, r.UDPRemotes) return new } @@ -121,26 +133,17 @@ func (r *Config) Different(new *Config) bool { if len(r.TCPRemotes) != len(new.TCPRemotes) { return true } - for i, addr := range r.TCPRemotes { if addr != new.TCPRemotes[i] { return true } } - if len(r.UDPRemotes) != len(new.UDPRemotes) { - return true - } - for i, addr := range r.UDPRemotes { - if addr != new.UDPRemotes[i] { - return true - } - } return false } // todo make this shorter and more readable func (r *Config) DefaultLabel() string { - defaultLabel := fmt.Sprintf("", + defaultLabel := fmt.Sprintf("", r.Listen, r.TCPRemotes, r.TransportType) return defaultLabel } @@ -150,6 +153,12 @@ func (r *Config) Adjust() error { r.Label = r.DefaultLabel() zap.S().Debugf("label is empty, set default label:%s", r.Label) } + if r.Options == nil { + r.Options = &Options{ + WSConfig: &WSConfig{}, + EnableMultipathTCP: true, + } + } return nil } @@ -171,20 +180,14 @@ func (r *Config) GetLoggerName() string { func (r *Config) validateType() error { if r.ListenType != constant.RelayTypeRaw && r.ListenType != constant.RelayTypeWS && - r.ListenType != constant.RelayTypeMWS && - r.ListenType != constant.RelayTypeWSS && - r.ListenType != constant.RelayTypeMTCP && - r.ListenType != constant.RelayTypeMWSS { + r.ListenType != constant.RelayTypeWSS { return fmt.Errorf("invalid listen type:%s", r.ListenType) } if r.TransportType != constant.RelayTypeRaw && r.TransportType != constant.RelayTypeWS && - r.TransportType != constant.RelayTypeMWS && - r.TransportType != constant.RelayTypeWSS && - r.TransportType != constant.RelayTypeMTCP && - r.TransportType != constant.RelayTypeMWSS { - return fmt.Errorf("invalid transport type:%s", r.ListenType) + r.TransportType != constant.RelayTypeWSS { + return fmt.Errorf("invalid transport type:%s", r.TransportType) } return nil } diff --git a/internal/relay/relay.go b/internal/relay/relay.go index 2e8829400..48eaeca95 100644 --- a/internal/relay/relay.go +++ b/internal/relay/relay.go @@ -1,6 +1,8 @@ package relay import ( + "context" + "go.uber.org/zap" "github.com/Ehco1996/ehco/internal/cmgr" @@ -33,17 +35,17 @@ func NewRelay(cfg *conf.Config, cmgr cmgr.Cmgr) (*Relay, error) { return r, nil } -func (r *Relay) ListenAndServe() error { +func (r *Relay) ListenAndServe(ctx context.Context) error { errCh := make(chan error) go func() { - r.l.Infof("Start TCP Relay Server:%s", r.cfg.DefaultLabel()) - errCh <- r.relayServer.ListenAndServe() + r.l.Infof("Start Relay Server(%s):%s", r.cfg.ListenType, r.cfg.DefaultLabel()) + errCh <- r.relayServer.ListenAndServe(ctx) }() return <-errCh } func (r *Relay) Close() { - r.l.Infof("Close TCP Relay Server:%s", r.cfg.DefaultLabel()) + r.l.Infof("Close Relay Server:%s", r.cfg.DefaultLabel()) if err := r.relayServer.Close(); err != nil { r.l.Errorf(err.Error()) } diff --git a/internal/relay/server.go b/internal/relay/server.go index cd677de0b..022be14ca 100644 --- a/internal/relay/server.go +++ b/internal/relay/server.go @@ -46,10 +46,10 @@ func NewServer(cfg *config.Config) (*Server, error) { return s, nil } -func (s *Server) startOneRelay(r *Relay) { +func (s *Server) startOneRelay(ctx context.Context, r *Relay) { s.relayM.Store(r.UniqueID(), r) // mute closed network error for tcp server and mute http.ErrServerClosed for http server when config reload - if err := r.ListenAndServe(); err != nil && + if err := r.ListenAndServe(ctx); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) { s.l.Errorf("start relay %s meet error: %s", r.UniqueID(), err) s.errCH <- err @@ -68,7 +68,7 @@ func (s *Server) Start(ctx context.Context) error { if err != nil { return err } - go s.startOneRelay(r) + go s.startOneRelay(ctx, r) } if s.cfg.PATH != "" && (s.cfg.ReloadInterval > 0 || len(s.cfg.SubConfigs) > 0) { diff --git a/internal/relay/server_reloader.go b/internal/relay/server_reloader.go index 7a88b8637..231378fc9 100644 --- a/internal/relay/server_reloader.go +++ b/internal/relay/server_reloader.go @@ -1,6 +1,8 @@ package relay import ( + "context" + "github.com/Ehco1996/ehco/internal/glue" "github.com/Ehco1996/ehco/internal/relay/conf" "go.uber.org/zap" @@ -48,7 +50,7 @@ func (s *Server) Reload(force bool) error { s.l.Error("new relay meet error", zap.Error(err)) continue } - go s.startOneRelay(r) + go s.startOneRelay(context.TODO(), r) } else { // when label not change, check if config changed oldCfg, ok := oldRelayCfgM[newCfg.Label] @@ -66,7 +68,7 @@ func (s *Server) Reload(force bool) error { s.l.Error("new relay meet error", zap.Error(err)) continue } - go s.startOneRelay(r) + go s.startOneRelay(context.TODO(), r) } } } diff --git a/internal/transporter/base.go b/internal/transporter/base.go index 22923559a..84efdfa12 100644 --- a/internal/transporter/base.go +++ b/internal/transporter/base.go @@ -18,97 +18,160 @@ import ( "go.uber.org/zap" ) -type baseTransporter struct { - cfg *conf.Config - l *zap.SugaredLogger +var _ RelayServer = &BaseRelayServer{} - cmgr cmgr.Cmgr - tCPRemotes lb.RoundRobin - relayer RelayClient +type BaseRelayServer struct { + cmgr cmgr.Cmgr + cfg *conf.Config + l *zap.SugaredLogger + + remotes lb.RoundRobin + relayer RelayClient } -func NewBaseTransporter(cfg *conf.Config, cmgr cmgr.Cmgr) (*baseTransporter, error) { +func newBaseRelayServer(cfg *conf.Config, cmgr cmgr.Cmgr) (*BaseRelayServer, error) { relayer, err := newRelayClient(cfg) if err != nil { return nil, err } - return &baseTransporter{ - cfg: cfg, - cmgr: cmgr, - tCPRemotes: cfg.ToTCPRemotes(), - l: zap.S().Named(cfg.GetLoggerName()), - relayer: relayer, + return &BaseRelayServer{ + relayer: relayer, + cfg: cfg, + cmgr: cmgr, + remotes: cfg.ToTCPRemotes(), + l: zap.S().Named(cfg.GetLoggerName()), }, nil } -func (b *baseTransporter) GetTCPListenAddr() (*net.TCPAddr, error) { - return net.ResolveTCPAddr("tcp", b.cfg.Listen) +func (b *BaseRelayServer) RelayTCPConn(ctx context.Context, c net.Conn) error { + remote := b.remotes.Next().Clone() + metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_TCP).Inc() + defer metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_TCP).Dec() + + if err := b.checkConnectionLimit(); err != nil { + return err + } + + var err error + c, err = b.sniffAndBlockProtocol(c) + if err != nil { + return err + } + c = b.applyRateLimit(c) + + rc, err := b.relayer.HandShake(ctx, remote, true) + if err != nil { + return fmt.Errorf("handshake error: %w", err) + } + defer rc.Close() + + b.l.Infof("RelayTCPConn from %s to %s", c.LocalAddr(), remote.Address) + return b.handleRelayConn(c, rc, remote, metrics.METRIC_CONN_TYPE_TCP) } -func (b *baseTransporter) GetRemote() *lb.Node { - return b.tCPRemotes.Next() +func (b *BaseRelayServer) RelayUDPConn(ctx context.Context, c net.Conn) error { + remote := b.remotes.Next().Clone() + metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_UDP).Inc() + defer metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_UDP).Dec() + + rc, err := b.relayer.HandShake(ctx, remote, false) + if err != nil { + return fmt.Errorf("handshake error: %w", err) + } + defer rc.Close() + + b.l.Infof("RelayUDPConn from %s to %s", c.LocalAddr(), remote.Address) + return b.handleRelayConn(c, rc, remote, metrics.METRIC_CONN_TYPE_UDP) } -func (b *baseTransporter) RelayTCPConn(c net.Conn, handshakeF TCPHandShakeF) error { - remote := b.GetRemote() - metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_TCP).Inc() - defer metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_TCP).Dec() +func (b *BaseRelayServer) checkConnectionLimit() error { + if b.cfg.Options.MaxConnection > 0 && b.cmgr.CountConnection(cmgr.ConnectionTypeActive) >= b.cfg.Options.MaxConnection { + return fmt.Errorf("relay:%s active connection count exceed limit %d", b.cfg.Label, b.cfg.Options.MaxConnection) + } + return nil +} - // check limit - if b.cfg.MaxConnection > 0 && b.cmgr.CountConnection(cmgr.ConnectionTypeActive) >= b.cfg.MaxConnection { - c.Close() - return fmt.Errorf("relay:%s active connection count exceed limit %d", b.cfg.Label, b.cfg.MaxConnection) +func (b *BaseRelayServer) sniffAndBlockProtocol(c net.Conn) (net.Conn, error) { + if len(b.cfg.Options.BlockedProtocols) == 0 { + return c, nil } - // sniff protocol - if len(b.cfg.BlockedProtocols) > 0 { - buffer := buf.NewPacket() - ctx := context.TODO() - sniffMetadata, err := sniff.PeekStream( - ctx, c, buffer, constant.SniffTimeOut, - sniff.TLSClientHello, sniff.HTTPHost) - if err != nil { - // this mean no protocol sniffed - b.l.Debug("sniff error: %s", err) - } - if sniffMetadata != nil { - b.l.Infof("sniffed protocol: %s", sniffMetadata.Protocol) - for _, p := range b.cfg.BlockedProtocols { - if sniffMetadata.Protocol == p { - c.Close() - return fmt.Errorf("relay:%s want to relay blocked protocol:%s", b.cfg.Label, sniffMetadata.Protocol) - } + buffer := buf.NewPacket() + + ctx, cancel := context.WithTimeout(context.Background(), constant.SniffTimeOut) + defer cancel() + + sniffMetadata, err := sniff.PeekStream(ctx, c, buffer, constant.SniffTimeOut, sniff.TLSClientHello, sniff.HTTPHost) + if err != nil { + b.l.Debugf("sniff error: %s", err) + return c, nil + } + + if sniffMetadata != nil { + b.l.Infof("sniffed protocol: %s", sniffMetadata.Protocol) + for _, p := range b.cfg.Options.BlockedProtocols { + if sniffMetadata.Protocol == p { + return c, fmt.Errorf("relay:%s blocked protocol:%s", b.cfg.Label, sniffMetadata.Protocol) } } - if !buffer.IsEmpty() { - c = bufio.NewCachedConn(c, buffer) - } else { - buffer.Release() - } } - // rate limit - if b.cfg.MaxReadRateKbps > 0 { - c = conn.NewRateLimitedConn(c, b.cfg.MaxReadRateKbps) + if !buffer.IsEmpty() { + return bufio.NewCachedConn(c, buffer), nil + } else { + buffer.Release() } + return c, nil +} - clonedRemote := remote.Clone() - rc, err := handshakeF(clonedRemote) - if err != nil { - return err +func (b *BaseRelayServer) applyRateLimit(c net.Conn) net.Conn { + if b.cfg.Options.MaxReadRateKbps > 0 { + return conn.NewRateLimitedConn(c, b.cfg.Options.MaxReadRateKbps) } - defer rc.Close() + return c +} - b.l.Infof("RelayTCPConn from %s to %s", c.LocalAddr(), remote.Address) - relayConn := conn.NewRelayConn( - b.cfg.Label, c, rc, conn.WithHandshakeDuration(clonedRemote.HandShakeDuration)) +func (b *BaseRelayServer) handleRelayConn(c, rc net.Conn, remote *lb.Node, connType string) error { + opts := []conn.RelayConnOption{ + conn.WithLogger(b.l), + conn.WithRemote(remote), + conn.WithConnType(connType), + conn.WithRelayLabel(b.cfg.Label), + conn.WithHandshakeDuration(remote.HandShakeDuration), + } + relayConn := conn.NewRelayConn(c, rc, opts...) b.cmgr.AddConnection(relayConn) defer b.cmgr.RemoveConnection(relayConn) - return relayConn.Transport(remote.Label) + return relayConn.Transport() } -func (b *baseTransporter) HealthCheck(ctx context.Context) (int64, error) { - remote := b.GetRemote().Clone() - err := b.relayer.HealthCheck(ctx, remote) +func (b *BaseRelayServer) HealthCheck(ctx context.Context) (int64, error) { + remote := b.remotes.Next().Clone() + // us tcp handshake to check health + _, err := b.relayer.HandShake(ctx, remote, true) return int64(remote.HandShakeDuration.Milliseconds()), err } + +func (b *BaseRelayServer) Close() error { + return fmt.Errorf("not implemented") +} + +func (b *BaseRelayServer) ListenAndServe(ctx context.Context) error { + return fmt.Errorf("not implemented") +} + +func NewNetDialer(cfg *conf.Config) *net.Dialer { + dialer := &net.Dialer{Timeout: constant.DialTimeOut} + dialer.SetMultipathTCP(cfg.Options.EnableMultipathTCP) + return dialer +} + +func NewTCPListener(ctx context.Context, cfg *conf.Config) (net.Listener, error) { + addr, err := net.ResolveTCPAddr("tcp", cfg.Listen) + if err != nil { + return nil, err + } + lcfg := net.ListenConfig{} + lcfg.SetMultipathTCP(cfg.Options.EnableMultipathTCP) + return lcfg.Listen(ctx, "tcp", addr.String()) +} diff --git a/internal/transporter/interface.go b/internal/transporter/interface.go index 301d78b00..5af79579a 100644 --- a/internal/transporter/interface.go +++ b/internal/transporter/interface.go @@ -11,56 +11,45 @@ import ( "github.com/Ehco1996/ehco/internal/relay/conf" ) -type TCPHandShakeF func(remote *lb.Node) (net.Conn, error) - +// TODO opt this interface type RelayClient interface { - HealthCheck(ctx context.Context, remote *lb.Node) error - TCPHandShake(remote *lb.Node) (net.Conn, error) + HandShake(ctx context.Context, remote *lb.Node, isTCP bool) (net.Conn, error) } func newRelayClient(cfg *conf.Config) (RelayClient, error) { switch cfg.TransportType { case constant.RelayTypeRaw: return newRawClient(cfg) - case constant.RelayTypeMTCP: - return newMtcpClient(cfg) case constant.RelayTypeWS: return newWsClient(cfg) - case constant.RelayTypeMWS: - return newMwsClient(cfg) case constant.RelayTypeWSS: return newWssClient(cfg) - case constant.RelayTypeMWSS: - return newMwssClient(cfg) default: return nil, fmt.Errorf("unsupported transport type: %s", cfg.TransportType) } } type RelayServer interface { - ListenAndServe() error + ListenAndServe(ctx context.Context) error Close() error + + RelayTCPConn(ctx context.Context, c net.Conn) error + RelayUDPConn(ctx context.Context, c net.Conn) error HealthCheck(ctx context.Context) (int64, error) // latency in ms } func NewRelayServer(cfg *conf.Config, cmgr cmgr.Cmgr) (RelayServer, error) { - base, err := NewBaseTransporter(cfg, cmgr) + base, err := newBaseRelayServer(cfg, cmgr) if err != nil { return nil, err } switch cfg.ListenType { case constant.RelayTypeRaw: return newRawServer(base) - case constant.RelayTypeMTCP: - return newMtcpServer(base) case constant.RelayTypeWS: return newWsServer(base) - case constant.RelayTypeMWS: - return newMwsServer(base) case constant.RelayTypeWSS: return newWssServer(base) - case constant.RelayTypeMWSS: - return newMwssServer(base) default: panic("unsupported transport type" + cfg.ListenType) } diff --git a/internal/transporter/mux.go b/internal/transporter/mux.go deleted file mode 100644 index 5bc9d3d58..000000000 --- a/internal/transporter/mux.go +++ /dev/null @@ -1,201 +0,0 @@ -// nolint: errcheck -package transporter - -import ( - "context" - "net" - "sync" - "time" - - "github.com/Ehco1996/ehco/internal/constant" - "github.com/xtaci/smux" - "go.uber.org/zap" -) - -type smuxTransporter struct { - sessionMutex sync.Mutex - - gcTicker *time.Ticker - l *zap.SugaredLogger - - // remote addr -> SessionWithMetrics - sessionM map[string][]*SessionWithMetrics - - initSessionF func(ctx context.Context, addr string) (*smux.Session, error) -} - -type SessionWithMetrics struct { - session *smux.Session - - createdTime time.Time - streamList []*smux.Stream -} - -func (sm *SessionWithMetrics) CanNotServeNewStream() bool { - return sm.session.IsClosed() || - sm.session.NumStreams() >= constant.SmuxMaxStreamCnt || - time.Since(sm.createdTime) > constant.SmuxMaxAliveDuration -} - -func streamDead(s *smux.Stream) bool { - select { - case _, ok := <-s.GetDieCh(): - return !ok // 如果接收到值且通道未关闭,则 Stream 未死 - default: - return true // 如果通道已经关闭,则 Stream 死了 - } -} - -func (sm *SessionWithMetrics) canCloseSession(remoteAddr string, l *zap.SugaredLogger) bool { - for _, s := range sm.streamList { - if !streamDead(s) { - return false - } - l.Debugf("session: %s stream: %d is not dead", remoteAddr, s.ID()) - } - return true -} - -func NewSmuxTransporter( - l *zap.SugaredLogger, - initSessionF func(ctx context.Context, addr string) (*smux.Session, error), -) *smuxTransporter { - tr := &smuxTransporter{ - l: l, - initSessionF: initSessionF, - sessionM: make(map[string][]*SessionWithMetrics), - gcTicker: time.NewTicker(constant.SmuxGCDuration), - } - // start gc thread for close idle sessions - go tr.gc() - return tr -} - -func (tr *smuxTransporter) gc() { - for range tr.gcTicker.C { - tr.sessionMutex.Lock() - for addr, sl := range tr.sessionM { - tr.l.Debugf("start doing gc for remote addr: %s total session count %d", addr, len(sl)) - for idx := range sl { - sm := sl[idx] - if sm.CanNotServeNewStream() && sm.canCloseSession(addr, tr.l) { - tr.l.Debugf("close idle session:%s stream cnt %d", - sm.session.LocalAddr().String(), sm.session.NumStreams()) - sm.session.Close() - } - } - newList := []*SessionWithMetrics{} - for _, s := range sl { - if !s.session.IsClosed() { - newList = append(newList, s) - } - } - tr.sessionM[addr] = newList - tr.l.Debugf("finish gc for remote addr: %s total session count %d", addr, len(sl)) - } - tr.sessionMutex.Unlock() - } -} - -func (tr *smuxTransporter) Dial(ctx context.Context, addr string) (conn net.Conn, err error) { - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - var session *smux.Session - var curSM *SessionWithMetrics - - sessionList := tr.sessionM[addr] - for _, sm := range sessionList { - if sm.CanNotServeNewStream() { - continue - } else { - tr.l.Debugf("use session: %s total stream count: %d remote addr: %s", - sm.session.LocalAddr().String(), sm.session.NumStreams(), addr) - session = sm.session - curSM = sm - break - } - } - // create new one - if session == nil { - session, err = tr.initSessionF(ctx, addr) - if err != nil { - return nil, err - } - sm := &SessionWithMetrics{session: session, createdTime: time.Now(), streamList: []*smux.Stream{}} - sessionList = append(sessionList, sm) - tr.sessionM[addr] = sessionList - curSM = sm - } - - stream, err := session.OpenStream() - if err != nil { - tr.l.Errorf("open stream meet error:%s", err) - session.Close() - return nil, err - } - curSM.streamList = append(curSM.streamList, stream) - return stream, nil -} - -type muxServer interface { - ListenAndServe() error - Accept() (net.Conn, error) - Close() error - mux(net.Conn) -} - -func newMuxServer(listenAddr string, l *zap.SugaredLogger) *muxServerImpl { - return &muxServerImpl{ - errChan: make(chan error, 1), - connChan: make(chan net.Conn, 1024), - listenAddr: listenAddr, - l: l, - } -} - -type muxServerImpl struct { - errChan chan error - connChan chan net.Conn - - listenAddr string - l *zap.SugaredLogger -} - -func (s *muxServerImpl) Accept() (net.Conn, error) { - select { - case conn := <-s.connChan: - return conn, nil - case err := <-s.errChan: - return nil, err - } -} - -func (s *muxServerImpl) mux(conn net.Conn) { - defer conn.Close() - - cfg := smux.DefaultConfig() - cfg.KeepAliveDisabled = true - session, err := smux.Server(conn, cfg) - if err != nil { - s.l.Debugf("server err %s - %s : %s", conn.RemoteAddr(), s.listenAddr, err) - return - } - defer session.Close() // nolint: errcheck - - s.l.Debugf("session init %s %s", conn.RemoteAddr(), s.listenAddr) - defer s.l.Debugf("session close %s >-< %s", conn.RemoteAddr(), s.listenAddr) - - for { - stream, err := session.AcceptStream() - if err != nil { - s.l.Errorf("accept stream err: %s", err) - break - } - select { - case s.connChan <- stream: - default: - stream.Close() // nolint: errcheck - s.l.Infof("%s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) - } - } -} diff --git a/internal/transporter/raw.go b/internal/transporter/raw.go index 1890426ba..90531d263 100644 --- a/internal/transporter/raw.go +++ b/internal/transporter/raw.go @@ -3,10 +3,11 @@ package transporter import ( "context" + "errors" "net" "time" - "github.com/Ehco1996/ehco/internal/constant" + "github.com/Ehco1996/ehco/internal/conn" "github.com/Ehco1996/ehco/internal/lb" "github.com/Ehco1996/ehco/internal/metrics" "github.com/Ehco1996/ehco/internal/relay/conf" @@ -26,17 +27,22 @@ type RawClient struct { func newRawClient(cfg *conf.Config) (*RawClient, error) { r := &RawClient{ - l: zap.S().Named("raw"), cfg: cfg, - dialer: &net.Dialer{Timeout: constant.DialTimeOut}, + dialer: NewNetDialer(cfg), + l: zap.S().Named(string(cfg.TransportType)), } - r.dialer.SetMultipathTCP(true) return r, nil } -func (raw *RawClient) TCPHandShake(remote *lb.Node) (net.Conn, error) { +func (raw *RawClient) HandShake(ctx context.Context, remote *lb.Node, isTCP bool) (net.Conn, error) { t1 := time.Now() - rc, err := raw.dialer.Dial("tcp", remote.Address) + var rc net.Conn + var err error + if isTCP { + rc, err = raw.dialer.DialContext(ctx, "tcp", remote.Address) + } else { + rc, err = raw.dialer.DialContext(ctx, "udp", remote.Address) + } if err != nil { return nil, err } @@ -46,55 +52,72 @@ func (raw *RawClient) TCPHandShake(remote *lb.Node) (net.Conn, error) { return rc, nil } -func (raw *RawClient) HealthCheck(ctx context.Context, remote *lb.Node) error { - l := zap.S().Named("health-check") - l.Infof("start send req to %s", remote.Address) - c, err := raw.TCPHandShake(remote) - if err != nil { - l.Errorf("send req to %s meet error:%s", remote.Address, err) - return err - } - c.Close() - return nil +type RawServer struct { + *BaseRelayServer + + tcpLis net.Listener + udpLis *conn.UDPListener } -type RawServer struct { - *baseTransporter - lis net.Listener +func newRawServer(bs *BaseRelayServer) (*RawServer, error) { + rs := &RawServer{BaseRelayServer: bs} + + return rs, nil } -func newRawServer(base *baseTransporter) (*RawServer, error) { - addr, err := base.GetTCPListenAddr() - if err != nil { - return nil, err +func (s *RawServer) Close() error { + err := s.tcpLis.Close() + if s.udpLis != nil { + err2 := s.udpLis.Close() + err = errors.Join(err, err2) } - cfg := net.ListenConfig{} - cfg.SetMultipathTCP(true) - lis, err := cfg.Listen(context.TODO(), "tcp", addr.String()) + return err +} + +func (s *RawServer) ListenAndServe(ctx context.Context) error { + ts, err := NewTCPListener(ctx, s.cfg) if err != nil { - return nil, err + return err } - return &RawServer{ - lis: lis, - baseTransporter: base, - }, nil -} + s.tcpLis = ts -func (s *RawServer) Close() error { - return s.lis.Close() -} + if s.cfg.Options != nil && s.cfg.Options.EnableUDP { + udpLis, err := conn.NewUDPListener(ctx, s.cfg.Listen) + if err != nil { + return err + } + s.udpLis = udpLis + } -func (s *RawServer) ListenAndServe() error { + if s.udpLis != nil { + go s.listenUDP(ctx) + } for { - c, err := s.lis.Accept() + c, err := s.tcpLis.Accept() if err != nil { return err } go func(c net.Conn) { defer c.Close() - if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil { - s.l.Errorf("RelayTCPConn error: %s", err.Error()) + if err := s.RelayTCPConn(ctx, c); err != nil { + s.l.Errorf("RelayTCPConn meet error: %s", err.Error()) } }(c) } } + +func (s *RawServer) listenUDP(ctx context.Context) error { + s.l.Infof("Start UDP server at %s", s.cfg.Listen) + for { + c, err := s.udpLis.Accept() + if err != nil { + s.l.Errorf("UDP accept error: %v", err) + return err + } + go func() { + if err := s.RelayUDPConn(ctx, c); err != nil { + s.l.Errorf("RelayUDPConn meet error: %s", err.Error()) + } + }() + } +} diff --git a/internal/transporter/raw_mux.go b/internal/transporter/raw_mux.go deleted file mode 100644 index bc7ed22cd..000000000 --- a/internal/transporter/raw_mux.go +++ /dev/null @@ -1,109 +0,0 @@ -package transporter - -import ( - "context" - "net" - "time" - - "github.com/xtaci/smux" - "go.uber.org/zap" - - "github.com/Ehco1996/ehco/internal/lb" - "github.com/Ehco1996/ehco/internal/metrics" - "github.com/Ehco1996/ehco/internal/relay/conf" -) - -var ( - _ RelayClient = &MtcpClient{} - _ RelayServer = &MtcpServer{} -) - -type MtcpClient struct { - *RawClient - muxTP *smuxTransporter -} - -func newMtcpClient(cfg *conf.Config) (*MtcpClient, error) { - raw, err := newRawClient(cfg) - if err != nil { - return nil, err - } - c := &MtcpClient{RawClient: raw} - c.muxTP = NewSmuxTransporter(zap.S().Named("mtcp"), c.initNewSession) - return c, nil -} - -func (c *MtcpClient) initNewSession(ctx context.Context, addr string) (*smux.Session, error) { - rc, err := c.dialer.Dial("tcp", addr) - if err != nil { - return nil, err - } - // stream multiplex - cfg := smux.DefaultConfig() - cfg.KeepAliveDisabled = true - session, err := smux.Client(rc, cfg) - if err != nil { - return nil, err - } - c.l.Infof("init new session to: %s", rc.RemoteAddr()) - return session, nil -} - -func (s *MtcpClient) TCPHandShake(remote *lb.Node) (net.Conn, error) { - t1 := time.Now() - mtcpc, err := s.muxTP.Dial(context.TODO(), remote.Address) - if err != nil { - return nil, err - } - latency := time.Since(t1) - metrics.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(latency.Milliseconds())) - remote.HandShakeDuration = latency - return mtcpc, nil -} - -type MtcpServer struct { - *RawServer - *muxServerImpl -} - -func newMtcpServer(base *baseTransporter) (*MtcpServer, error) { - raw, err := newRawServer(base) - if err != nil { - return nil, err - } - s := &MtcpServer{ - RawServer: raw, - muxServerImpl: newMuxServer(base.cfg.Listen, base.l.Named("mtcp")), - } - - return s, nil -} - -func (s *MtcpServer) ListenAndServe() error { - go func() { - for { - c, err := s.lis.Accept() - if err != nil { - s.errChan <- err - continue - } - go s.mux(c) - } - }() - - for { - conn, e := s.Accept() - if e != nil { - return e - } - go func(c net.Conn) { - if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil { - s.l.Errorf("RelayTCPConn error: %s", err.Error()) - } - }(conn) - } -} - -func (s *MtcpServer) Close() error { - return s.lis.Close() -} diff --git a/internal/transporter/ws.go b/internal/transporter/ws.go index d3e01b51a..bdf18b472 100644 --- a/internal/transporter/ws.go +++ b/internal/transporter/ws.go @@ -4,6 +4,7 @@ import ( "context" "net" "net/http" + "net/url" "time" "github.com/gobwas/ws" @@ -11,7 +12,6 @@ import ( "go.uber.org/zap" "github.com/Ehco1996/ehco/internal/conn" - "github.com/Ehco1996/ehco/internal/constant" "github.com/Ehco1996/ehco/internal/lb" "github.com/Ehco1996/ehco/internal/metrics" "github.com/Ehco1996/ehco/internal/relay/conf" @@ -24,27 +24,48 @@ var ( ) type WsClient struct { - dialer *ws.Dialer - cfg *conf.Config - l *zap.SugaredLogger + dialer *ws.Dialer + cfg *conf.Config + netDialer *net.Dialer + l *zap.SugaredLogger } func newWsClient(cfg *conf.Config) (*WsClient, error) { s := &WsClient{ - cfg: cfg, - l: zap.S().Named(string(cfg.TransportType)), - dialer: &ws.Dialer{Timeout: constant.DialTimeOut}, + cfg: cfg, + netDialer: NewNetDialer(cfg), + l: zap.S().Named(string(cfg.TransportType)), + dialer: &ws.DefaultDialer, // todo config buffer size + } + s.dialer.NetDial = func(ctx context.Context, network, addr string) (net.Conn, error) { + return s.netDialer.Dial(network, addr) } return s, nil } -func (s *WsClient) TCPHandShake(remote *lb.Node) (net.Conn, error) { +func (s *WsClient) addUDPQueryParam(addr string) string { + u, err := url.Parse(addr) + if err != nil { + s.l.Errorf("Failed to parse URL: %v", err) + return addr + } + q := u.Query() + q.Set("type", "udp") + u.RawQuery = q.Encode() + return u.String() +} + +func (s *WsClient) HandShake(ctx context.Context, remote *lb.Node, isTCP bool) (net.Conn, error) { t1 := time.Now() addr, err := s.cfg.GetWSRemoteAddr(remote.Address) if err != nil { return nil, err } - wsc, _, _, err := s.dialer.Dial(context.TODO(), addr) + if !isTCP { + addr = s.addUDPQueryParam(addr) + } + + wsc, _, _, err := s.dialer.Dial(ctx, addr) if err != nil { return nil, err } @@ -55,61 +76,50 @@ func (s *WsClient) TCPHandShake(remote *lb.Node) (net.Conn, error) { return c, nil } -func (s *WsClient) HealthCheck(ctx context.Context, remote *lb.Node) error { - l := zap.S().Named("health-check") - l.Infof("start send req to %s", remote.Address) - c, err := s.TCPHandShake(remote) - if err != nil { - l.Errorf("send req to %s meet error:%s", remote.Address, err) - return err - } - c.Close() - return nil -} - type WsServer struct { - *baseTransporter - - e *echo.Echo + *BaseRelayServer httpServer *http.Server } -func newWsServer(base *baseTransporter) (*WsServer, error) { - localTCPAddr, err := base.GetTCPListenAddr() - if err != nil { - return nil, err - } - s := &WsServer{ - baseTransporter: base, - httpServer: &http.Server{ - Addr: localTCPAddr.String(), ReadHeaderTimeout: 30 * time.Second, - }, - } +func newWsServer(bs *BaseRelayServer) (*WsServer, error) { + s := &WsServer{BaseRelayServer: bs} e := web.NewEchoServer() e.Use(web.NginxLogMiddleware(zap.S().Named("ws-server"))) - e.GET("/", echo.WrapHandler(web.MakeIndexF())) - e.GET(base.cfg.GetWSHandShakePath(), echo.WrapHandler(http.HandlerFunc(s.HandleRequest))) - - s.e = e + e.GET(bs.cfg.GetWSHandShakePath(), echo.WrapHandler(http.HandlerFunc(s.handleRequest))) + s.httpServer = &http.Server{Handler: e} return s, nil } -func (s *WsServer) ListenAndServe() error { - return s.e.StartServer(s.httpServer) -} - -func (s *WsServer) Close() error { - return s.e.Close() -} - -func (s *WsServer) HandleRequest(w http.ResponseWriter, req *http.Request) { +func (s *WsServer) handleRequest(w http.ResponseWriter, req *http.Request) { + // todo use bufio.ReadWriter wsc, _, _, err := ws.UpgradeHTTP(req, w) if err != nil { return } + if req.URL.Query().Get("type") == "udp" { + if !s.cfg.Options.EnableUDP { + s.l.Error("udp not support but request with udp type") + wsc.Close() + return + } + err = s.RelayUDPConn(req.Context(), conn.NewWSConn(wsc, true)) + } else { + err = s.RelayTCPConn(req.Context(), conn.NewWSConn(wsc, true)) + } + if err != nil { + s.l.Errorf("handleRequest meet error:%s", err) + } +} - if err := s.RelayTCPConn(conn.NewWSConn(wsc, true), s.relayer.TCPHandShake); err != nil { - s.l.Errorf("RelayTCPConn error: %s", err.Error()) +func (s *WsServer) ListenAndServe(ctx context.Context) error { + listener, err := NewTCPListener(ctx, s.cfg) + if err != nil { + return err } + return s.httpServer.Serve(listener) +} + +func (s *WsServer) Close() error { + return s.httpServer.Close() } diff --git a/internal/transporter/ws_mux.go b/internal/transporter/ws_mux.go deleted file mode 100644 index 5260a5ae4..000000000 --- a/internal/transporter/ws_mux.go +++ /dev/null @@ -1,122 +0,0 @@ -// NOTE CAN NOT use real ws frame to transport smux frame -// err: accept stream err: buffer size:8 too small to transport ws payload size:45 -// so this transport just use ws protocol to handshake and then use smux protocol to transport -package transporter - -import ( - "context" - "net" - "net/http" - "time" - - "github.com/gobwas/ws" - "github.com/labstack/echo/v4" - "github.com/xtaci/smux" - - "github.com/Ehco1996/ehco/internal/lb" - "github.com/Ehco1996/ehco/internal/metrics" - "github.com/Ehco1996/ehco/internal/relay/conf" -) - -var ( - _ RelayClient = &MwsClient{} - _ RelayServer = &MwsServer{} - _ muxServer = &MwsServer{} -) - -type MwsClient struct { - *WssClient - - muxTP *smuxTransporter -} - -func newMwsClient(cfg *conf.Config) (*MwsClient, error) { - wc, err := newWssClient(cfg) - if err != nil { - return nil, err - } - c := &MwsClient{WssClient: wc} - c.muxTP = NewSmuxTransporter(c.l.Named("mwss"), c.initNewSession) - return c, nil -} - -func (c *MwsClient) initNewSession(ctx context.Context, addr string) (*smux.Session, error) { - rc, _, _, err := c.dialer.Dial(ctx, addr) - if err != nil { - return nil, err - } - // stream multiplex - cfg := smux.DefaultConfig() - cfg.KeepAliveDisabled = true - session, err := smux.Client(rc, cfg) - if err != nil { - return nil, err - } - c.l.Infof("init new session to: %s", rc.RemoteAddr()) - return session, nil -} - -func (s *MwsClient) TCPHandShake(remote *lb.Node) (net.Conn, error) { - t1 := time.Now() - addr, err := s.cfg.GetWSRemoteAddr(remote.Address) - if err != nil { - return nil, err - } - mwssc, err := s.muxTP.Dial(context.TODO(), addr) - if err != nil { - return nil, err - } - latency := time.Since(t1) - metrics.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(latency.Milliseconds())) - remote.HandShakeDuration = latency - return mwssc, nil -} - -type MwsServer struct { - *WsServer - *muxServerImpl -} - -func newMwsServer(base *baseTransporter) (*MwsServer, error) { - wsServer, err := newWsServer(base) - if err != nil { - return nil, err - } - s := &MwsServer{ - WsServer: wsServer, - muxServerImpl: newMuxServer(base.cfg.Listen, base.l.Named("mwss")), - } - s.e.GET(base.cfg.GetWSHandShakePath(), echo.WrapHandler(http.HandlerFunc(s.HandleRequest))) - return s, nil -} - -func (s *MwsServer) ListenAndServe() error { - go func() { - s.errChan <- s.e.StartServer(s.httpServer) - }() - - for { - conn, e := s.Accept() - if e != nil { - return e - } - go func(c net.Conn) { - if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil { - s.l.Errorf("RelayTCPConn error: %s", err.Error()) - } - }(conn) - } -} - -func (s *MwsServer) HandleRequest(w http.ResponseWriter, r *http.Request) { - c, _, _, err := ws.UpgradeHTTP(r, w) - if err != nil { - s.l.Error(err) - return - } - s.mux(c) -} - -func (s *MwsServer) Close() error { - return s.e.Close() -} diff --git a/internal/transporter/wss.go b/internal/transporter/wss.go index 74391f078..dc157a614 100644 --- a/internal/transporter/wss.go +++ b/internal/transporter/wss.go @@ -1,6 +1,9 @@ package transporter import ( + "context" + "crypto/tls" + "github.com/Ehco1996/ehco/internal/relay/conf" mytls "github.com/Ehco1996/ehco/internal/tls" ) @@ -28,12 +31,19 @@ type WssServer struct { *WsServer } -func newWssServer(base *baseTransporter) (*WssServer, error) { - wsServer, err := newWsServer(base) +func newWssServer(bs *BaseRelayServer) (*WssServer, error) { + wsServer, err := newWsServer(bs) if err != nil { return nil, err } - // insert tls config - wsServer.httpServer.TLSConfig = mytls.DefaultTLSConfig return &WssServer{WsServer: wsServer}, nil } + +func (s *WssServer) ListenAndServe(ctx context.Context) error { + listener, err := NewTCPListener(ctx, s.cfg) + if err != nil { + return err + } + tlsListener := tls.NewListener(listener, mytls.DefaultTLSConfig) + return s.httpServer.Serve(tlsListener) +} diff --git a/internal/transporter/wss_mux.go b/internal/transporter/wss_mux.go deleted file mode 100644 index 5c38925d0..000000000 --- a/internal/transporter/wss_mux.go +++ /dev/null @@ -1,122 +0,0 @@ -// NOTE CAN NOT use real ws frame to transport smux frame -// err: accept stream err: buffer size:8 too small to transport ws payload size:45 -// so this transport just use ws protocol to handshake and then use smux protocol to transport -package transporter - -import ( - "context" - "net" - "net/http" - "time" - - "github.com/gobwas/ws" - "github.com/labstack/echo/v4" - "github.com/xtaci/smux" - - "github.com/Ehco1996/ehco/internal/lb" - "github.com/Ehco1996/ehco/internal/metrics" - "github.com/Ehco1996/ehco/internal/relay/conf" -) - -var ( - _ RelayClient = &MwssClient{} - _ RelayServer = &MwssServer{} - _ muxServer = &MwssServer{} -) - -type MwssClient struct { - *WssClient - - muxTP *smuxTransporter -} - -func newMwssClient(cfg *conf.Config) (*MwssClient, error) { - wc, err := newWssClient(cfg) - if err != nil { - return nil, err - } - c := &MwssClient{WssClient: wc} - c.muxTP = NewSmuxTransporter(c.l.Named("mwss"), c.initNewSession) - return c, nil -} - -func (c *MwssClient) initNewSession(ctx context.Context, addr string) (*smux.Session, error) { - rc, _, _, err := c.dialer.Dial(ctx, addr) - if err != nil { - return nil, err - } - // stream multiplex - cfg := smux.DefaultConfig() - cfg.KeepAliveDisabled = true - session, err := smux.Client(rc, cfg) - if err != nil { - return nil, err - } - c.l.Infof("init new session to: %s", rc.RemoteAddr()) - return session, nil -} - -func (s *MwssClient) TCPHandShake(remote *lb.Node) (net.Conn, error) { - t1 := time.Now() - addr, err := s.cfg.GetWSRemoteAddr(remote.Address) - if err != nil { - return nil, err - } - mwssc, err := s.muxTP.Dial(context.TODO(), addr) - if err != nil { - return nil, err - } - latency := time.Since(t1) - metrics.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(latency.Milliseconds())) - remote.HandShakeDuration = latency - return mwssc, nil -} - -type MwssServer struct { - *WssServer - *muxServerImpl -} - -func newMwssServer(base *baseTransporter) (*MwssServer, error) { - wssServer, err := newWssServer(base) - if err != nil { - return nil, err - } - s := &MwssServer{ - WssServer: wssServer, - muxServerImpl: newMuxServer(base.cfg.Listen, base.l.Named("mwss")), - } - s.e.GET(base.cfg.GetWSHandShakePath(), echo.WrapHandler(http.HandlerFunc(s.HandleRequest))) - return s, nil -} - -func (s *MwssServer) ListenAndServe() error { - go func() { - s.errChan <- s.e.StartServer(s.httpServer) - }() - - for { - conn, e := s.Accept() - if e != nil { - return e - } - go func(c net.Conn) { - if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil { - s.l.Errorf("RelayTCPConn error: %s", err.Error()) - } - }(conn) - } -} - -func (s *MwssServer) HandleRequest(w http.ResponseWriter, r *http.Request) { - c, _, _, err := ws.UpgradeHTTP(r, w) - if err != nil { - s.l.Error(err) - return - } - s.mux(c) -} - -func (s *MwssServer) Close() error { - return s.e.Close() -} diff --git a/internal/web/templates/connection.html b/internal/web/templates/connection.html index c32f8109a..6a66c4ec8 100644 --- a/internal/web/templates/connection.html +++ b/internal/web/templates/connection.html @@ -28,6 +28,7 @@

ALL Connections: {{.AllCount}}

Relay Label + Type Flow Stats Time @@ -37,7 +38,8 @@

ALL Connections: {{.AllCount}}

{{range .ConnectionList}} {{.RelayLabel}} - {{.Flow}} + {{.ConnType}} + {{.GetFlow}} {{.Stats}} {{.GetTime}} diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go index 333636573..63a25db1f 100644 --- a/pkg/buffer/buffer.go +++ b/pkg/buffer/buffer.go @@ -6,11 +6,13 @@ import ( // 全局pool var ( - BufferPool *BytePool + BufferPool *BytePool + UDPBufferPool *BytePool ) func init() { BufferPool = NewBytePool(constant.BUFFER_POOL_SIZE, constant.BUFFER_SIZE) + UDPBufferPool = NewBytePool(constant.BUFFER_POOL_SIZE, constant.UDPBufSize) } // BytePool implements a leaky pool of []byte in the form of a bounded channel diff --git a/pkg/metric_reader/reader.go b/pkg/metric_reader/reader.go index d6e227da9..7e8c72a96 100644 --- a/pkg/metric_reader/reader.go +++ b/pkg/metric_reader/reader.go @@ -35,7 +35,7 @@ func (b *readerImpl) parsePingInfo(metricMap map[string]*dto.MetricFamily, nm *N metric, ok := metricMap["ehco_ping_response_duration_seconds"] if !ok { // this metric is optional when enable_ping = false - zap.S().Warn("ping metric not found") + zap.S().Debug("ping metric not found") return nil } for _, m := range metric.Metric { diff --git a/pkg/sub/clash.go b/pkg/sub/clash.go index e87520347..f22cac171 100644 --- a/pkg/sub/clash.go +++ b/pkg/sub/clash.go @@ -179,7 +179,9 @@ func (c *ClashSub) ToRelayConfigs(listenHost string) ([]*relay_cfg.Config, error } rc.TCPRemotes = append(rc.TCPRemotes, remote) if proxy.UDP { - rc.UDPRemotes = append(rc.UDPRemotes, remote) + rc.Options = &relay_cfg.Options{ + EnableUDP: true, + } } } relayConfigs = append(relayConfigs, rc) diff --git a/pkg/sub/clash_types.go b/pkg/sub/clash_types.go index f40f9e488..543d52134 100644 --- a/pkg/sub/clash_types.go +++ b/pkg/sub/clash_types.go @@ -139,7 +139,9 @@ func (p *Proxies) ToRelayConfig(listenHost string, listenPort string, newName st TCPRemotes: []string{remoteAddr}, } if p.UDP { - r.UDPRemotes = []string{remoteAddr} + r.Options = &relay_cfg.Options{ + EnableUDP: true, + } } if err := r.Validate(); err != nil { return nil, err diff --git a/test/relay_test.go b/test/relay_test.go index 64f63644a..2bb57df70 100644 --- a/test/relay_test.go +++ b/test/relay_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "fmt" - "net" "os" "testing" "time" @@ -28,8 +27,7 @@ const ( ECHO_PORT = 9002 ECHO_SERVER = "0.0.0.0:9002" - RAW_LISTEN = "0.0.0.0:1234" - RAW_LISTEN_WITH_MAX_CONNECTION = "0.0.0.0:2234" + RAW_LISTEN = "0.0.0.0:1234" WS_LISTEN = "0.0.0.0:1235" WS_REMOTE = "ws://0.0.0.0:2000" @@ -38,22 +36,15 @@ const ( WSS_LISTEN = "0.0.0.0:1236" WSS_REMOTE = "wss://0.0.0.0:2001" WSS_SERVER = "0.0.0.0:2001" - - MWSS_LISTEN = "0.0.0.0:1237" - MWSS_REMOTE = "wss://0.0.0.0:2002" - MWSS_SERVER = "0.0.0.0:2002" - - MTCP_LISTEN = "0.0.0.0:1238" - MTCP_REMOTE = "0.0.0.0:2003" - MTCP_SERVER = "0.0.0.0:2003" - - MWS_LISTEN = "0.0.0.0:1239" - MWS_REMOTE = "ws://0.0.0.0:2004" - MSS_SERVER = "0.0.0.0:2004" ) func TestMain(m *testing.M) { // Setup + + // change the idle timeout to 1 second to make connection close faster in test + constant.IdleTimeOut = time.Second + constant.ReadTimeOut = time.Second + _ = log.InitGlobalLogger("debug") _ = tls.InitTlsCfg() @@ -78,38 +69,35 @@ func TestMain(m *testing.M) { func startRelayServers() []*relay.Relay { cfg := config.Config{ - PATH: "", RelayConfigs: []*conf.Config{ - // raw cfg + // raw { Listen: RAW_LISTEN, ListenType: constant.RelayTypeRaw, TCPRemotes: []string{ECHO_SERVER}, - UDPRemotes: []string{ECHO_SERVER}, - TransportType: constant.RelayTypeRaw, - }, - // raw cfg with max connection - { - Listen: RAW_LISTEN_WITH_MAX_CONNECTION, - ListenType: constant.RelayTypeRaw, - TCPRemotes: []string{ECHO_SERVER}, - UDPRemotes: []string{ECHO_SERVER}, TransportType: constant.RelayTypeRaw, - MaxConnection: 1, + Options: &conf.Options{ + EnableUDP: true, + }, }, - // ws { Listen: WS_LISTEN, ListenType: constant.RelayTypeRaw, TCPRemotes: []string{WS_REMOTE}, TransportType: constant.RelayTypeWS, + Options: &conf.Options{ + EnableUDP: true, + }, }, { Listen: WS_SERVER, ListenType: constant.RelayTypeWS, TCPRemotes: []string{ECHO_SERVER}, TransportType: constant.RelayTypeRaw, + Options: &conf.Options{ + EnableUDP: true, + }, }, // wss @@ -118,65 +106,30 @@ func startRelayServers() []*relay.Relay { ListenType: constant.RelayTypeRaw, TCPRemotes: []string{WSS_REMOTE}, TransportType: constant.RelayTypeWSS, + Options: &conf.Options{ + EnableUDP: true, + }, }, { Listen: WSS_SERVER, ListenType: constant.RelayTypeWSS, TCPRemotes: []string{ECHO_SERVER}, TransportType: constant.RelayTypeRaw, - }, - - // mwss - { - Listen: MWSS_LISTEN, - ListenType: constant.RelayTypeRaw, - TCPRemotes: []string{MWSS_REMOTE}, - TransportType: constant.RelayTypeMWSS, - }, - { - Listen: MWSS_SERVER, - ListenType: constant.RelayTypeMWSS, - TCPRemotes: []string{ECHO_SERVER}, - TransportType: constant.RelayTypeRaw, - }, - - // mtcp - { - Listen: MTCP_LISTEN, - ListenType: constant.RelayTypeRaw, - TCPRemotes: []string{MTCP_REMOTE}, - TransportType: constant.RelayTypeMTCP, - }, - { - Listen: MTCP_SERVER, - ListenType: constant.RelayTypeMTCP, - TCPRemotes: []string{ECHO_SERVER}, - TransportType: constant.RelayTypeRaw, - }, - - // mws - { - Listen: MWS_LISTEN, - ListenType: constant.RelayTypeRaw, - TCPRemotes: []string{MWS_REMOTE}, - TransportType: constant.RelayTypeMWS, - }, - { - Listen: MSS_SERVER, - ListenType: constant.RelayTypeMWS, - TCPRemotes: []string{ECHO_SERVER}, - TransportType: constant.RelayTypeRaw, + Options: &conf.Options{ + EnableUDP: true, + }, }, }, } var servers []*relay.Relay for _, c := range cfg.RelayConfigs { + c.Adjust() r, err := relay.NewRelay(c, cmgr.NewCmgr(cmgr.DummyConfig)) if err != nil { zap.S().Fatal(err) } - go r.ListenAndServe() + go r.ListenAndServe(context.TODO()) servers = append(servers, r) } @@ -194,16 +147,14 @@ func TestRelay(t *testing.T) { {"Raw", RAW_LISTEN, "raw"}, {"WS", WS_LISTEN, "ws"}, {"WSS", WSS_LISTEN, "wss"}, - {"MWSS", MWSS_LISTEN, "mwss"}, - {"MTCP", MTCP_LISTEN, "mtcp"}, - {"MWS", MWS_LISTEN, "mws"}, } for _, tc := range testCases { tc := tc // capture range variable t.Run(tc.name, func(t *testing.T) { t.Parallel() - testRelayCommon(t, tc.address, tc.protocol, false) + testTCPRelay(t, tc.address, tc.protocol, false) + testUDPRelay(t, tc.address, false) }) } } @@ -214,21 +165,22 @@ func TestRelayConcurrent(t *testing.T) { address string concurrency int }{ - {"MWSS", MWSS_LISTEN, 10}, - {"MTCP", MTCP_LISTEN, 10}, - {"MWS", MWS_LISTEN, 10}, + {"Raw", RAW_LISTEN, 10}, + {"WS", WS_LISTEN, 10}, + {"WSS", WSS_LISTEN, 10}, } for _, tc := range testCases { tc := tc // capture range variable t.Run(tc.name, func(t *testing.T) { t.Parallel() - testRelayCommon(t, tc.address, tc.name, true, tc.concurrency) + testTCPRelay(t, tc.address, tc.name, true, tc.concurrency) + testUDPRelay(t, tc.address, true, tc.concurrency) }) } } -func testRelayCommon(t *testing.T, address, protocol string, concurrent bool, concurrency ...int) { +func testTCPRelay(t *testing.T, address, protocol string, concurrent bool, concurrency ...int) { t.Helper() msg := []byte("hello") @@ -264,40 +216,42 @@ func testRelayCommon(t *testing.T, address, protocol string, concurrent bool, co t.Logf("Test TCP over %s done!", protocol) } -func TestRelayWithMaxConnectionCount(t *testing.T) { - msg := []byte("hello") - - // First connection will be accepted - go func() { - err := echo.EchoTcpMsgLong(msg, time.Second, RAW_LISTEN_WITH_MAX_CONNECTION) - require.NoError(t, err, "First connection should be accepted") - }() - - // Wait for first connection - time.Sleep(time.Second) - - // Second connection should be rejected - err := echo.EchoTcpMsgLong(msg, time.Second, RAW_LISTEN_WITH_MAX_CONNECTION) - require.Error(t, err, "Second connection should be rejected") -} +func testUDPRelay(t *testing.T, address string, concurrent bool, concurrency ...int) { + t.Helper() + msg := []byte("hello udp") -func TestRelayWithDeadline(t *testing.T) { - logger, _ := zap.NewDevelopment() - msg := []byte("hello") - conn, err := net.Dial("tcp", RAW_LISTEN) - if err != nil { - logger.Sugar().Fatal(err) - } - defer conn.Close() - if _, err := conn.Write(msg); err != nil { - logger.Sugar().Fatal(err) + runTest := func() error { + res := echo.SendUdpMsg(msg, address) + if !bytes.Equal(msg, res) { + return fmt.Errorf("response mismatch: got %s, want %s", res, msg) + } + return nil } - buf := make([]byte, len(msg)) - constant.IdleTimeOut = time.Second // change for test - time.Sleep(constant.IdleTimeOut) - _, err = conn.Read(buf) - if err != nil { - logger.Sugar().Fatal("need error here") + if concurrent { + n := 10 + if len(concurrency) > 0 { + n = concurrency[0] + } + g, ctx := errgroup.WithContext(context.Background()) + for i := 0; i < n; i++ { + g.Go(func() error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return runTest() + } + }) + } + require.NoError(t, g.Wait(), "Concurrent test failed") + } else { + require.NoError(t, runTest(), "Single test failed") } + t.Logf("Test UDP over %s done!", address) +} + +func TestRelayIdleTimeout(t *testing.T) { + err := echo.EchoTcpMsgLong([]byte("hello"), time.Second, RAW_LISTEN) + require.Error(t, err, "Connection should be rejected") }