Skip to content

Commit

Permalink
ws: use ws frame to transport data
Browse files Browse the repository at this point in the history
  • Loading branch information
Ehco1996 committed May 30, 2024
1 parent 3726533 commit 6675ace
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 11 deletions.
1 change: 0 additions & 1 deletion internal/metrics/ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,3 @@ func (pg *PingGroup) Run() {
time.Sleep(splay)
}
}

6 changes: 4 additions & 2 deletions internal/transporter/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ func (s *WsClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
latency := time.Since(t1)
metrics.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(latency.Milliseconds()))
remote.HandShakeDuration = latency
return wsc, nil
c := newWsConn(wsc, false)
return c, nil
}

type WsServer struct {
Expand Down Expand Up @@ -90,7 +91,8 @@ func (s *WsServer) HandleRequest(w http.ResponseWriter, req *http.Request) {
if err != nil {
return
}
if err := s.RelayTCPConn(wsc, s.relayer.TCPHandShake); err != nil {
c := newWsConn(wsc, true)
if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil {
s.l.Errorf("RelayTCPConn error: %s", err.Error())
}
}
82 changes: 82 additions & 0 deletions internal/transporter/ws_conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package transporter

import (
"fmt"
"io"
"net"
"time"

"github.com/Ehco1996/ehco/pkg/buffer"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
)

type wsConn struct {
conn net.Conn
isServer bool
buf []byte
}

func newWsConn(conn net.Conn, isServer bool) *wsConn {
return &wsConn{conn: conn, isServer: isServer, buf: buffer.BufferPool.Get()}
}

func (c *wsConn) Read(b []byte) (n int, err error) {
header, err := ws.ReadHeader(c.conn)
if err != nil {
return 0, err
}
if header.Length > int64(cap(c.buf)) {
c.buf = make([]byte, header.Length)
}
payload := c.buf[:header.Length]
_, err = io.ReadFull(c.conn, payload)
if err != nil {
return 0, err
}
if header.Masked {
ws.Cipher(payload, header.Mask, 0)
}
if len(payload) > len(b) {
return 0, fmt.Errorf("buffer too small to transport ws msg")
}
copy(b, payload)
return len(payload), nil
}

func (c *wsConn) Write(b []byte) (n int, err error) {
if c.isServer {
err = wsutil.WriteServerBinary(c.conn, b)
} else {
err = wsutil.WriteClientBinary(c.conn, b)
}
if err != nil {
return 0, err
}
return len(b), nil
}

func (c *wsConn) Close() error {
defer buffer.BufferPool.Put(c.buf)
return c.conn.Close()
}

func (c *wsConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}

func (c *wsConn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}

func (c *wsConn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}

func (c *wsConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}

func (c *wsConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
70 changes: 70 additions & 0 deletions internal/transporter/ws_conn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package transporter

import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/gobwas/ws"
"github.com/stretchr/testify/assert"
)

func TestClientConn_ReadWrite(t *testing.T) {
data := []byte("hello")

// Create a WebSocket server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, _, _, err := ws.UpgradeHTTP(r, w)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
go func() {
defer conn.Close()
wsc := newWsConn(conn, true)

buf := make([]byte, 1024)
for {
n, err := wsc.Read(buf)
if err != nil {
return
}
assert.Equal(t, len(data), n)
assert.Equal(t, "hello", string(buf[:n]))
_, err = wsc.Write(buf[:n])
if err != nil {
return
}
}
}()
}))
defer server.Close()

// Create a WebSocket client
addr, err := url.Parse(server.URL)
if err != nil {
t.Fatal(err)
}
conn, _, _, err := ws.DefaultDialer.Dial(context.TODO(), "ws://"+addr.Host)
if err != nil {
t.Fatal(err)
}
defer conn.Close()

wsClientConn := newWsConn(conn, false)
for i := 0; i < 3; i++ {
// test write
n, err := wsClientConn.Write(data)
assert.NoError(t, err, "test cnt %d", i)
assert.Equal(t, len(data), n, "test cnt %d", i)

// test read
buf := make([]byte, 100)
n, err = wsClientConn.Read(buf)
assert.NoError(t, err, "test cnt %d", i)
assert.Equal(t, len(data), n, "test cnt %d", i)
assert.Equal(t, "hello", string(buf[:n]), "test cnt %d", i)
}
}
23 changes: 16 additions & 7 deletions test/cmd/tcp_client/main.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
package main

import "github.com/Ehco1996/ehco/test/echo"
import (
"time"

"github.com/Ehco1996/ehco/test/echo"
)

func main() {
msg := []byte("hello")

echoServerAddr := "127.0.0.1:2333"
println("real echo server at:", echoServerAddr)

// start ehco real server
// go run cmd/ehco/main.go -l 0.0.0.0:2234 -r 0.0.0.0:2333

relayAddr := "127.0.0.1:2234"
println("real echo server at:", echoServerAddr, "relay addr:", relayAddr)

ret := echo.SendTcpMsg(msg, relayAddr)
println(string(ret))
if string(ret) != "hello" {
panic("relay short failed")
}
println("test short conn success, hello sended and received")

if err := echo.EchoTcpMsgLong(msg, time.Second, relayAddr); err != nil {
panic("relay long failed:" + err.Error())
}
println("test long conn success")
}
3 changes: 2 additions & 1 deletion test/echo/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ func echo(conn net.Conn) {
logger.Error(err.Error())
return
}
println("echo server receive", string(buf[:i]))
_, err = conn.Write(buf[:i])
if err != nil {
logger.Error(err.Error())
Expand Down Expand Up @@ -135,7 +136,7 @@ func EchoTcpMsgLong(msg []byte, sleepTime time.Duration, address string) error {
return err
}
if string(buf[:n]) != string(msg) {
return fmt.Errorf("msg not equal")
return fmt.Errorf("msg not equal at %d send:%s receive:%s n:%d", i, msg, buf[:n], n)
}
// to fake a long connection
time.Sleep(sleepTime)
Expand Down
16 changes: 16 additions & 0 deletions test/echo/ws.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"relay_configs": [
{
"listen": "127.0.0.1:2234",
"listen_type": "raw",
"transport_type": "ws",
"tcp_remotes": ["ws://0.0.0.0:2443"]
},
{
"listen": "127.0.0.1:2443",
"listen_type": "ws",
"transport_type": "raw",
"tcp_remotes": ["127.0.0.1:2333"]
}
]
}

0 comments on commit 6675ace

Please sign in to comment.