Skip to content

Commit 49f7d1e

Browse files
committed
公开更多函数
1 parent cef3584 commit 49f7d1e

File tree

8 files changed

+702
-522
lines changed

8 files changed

+702
-522
lines changed

client.go

+151-162
Original file line numberDiff line numberDiff line change
@@ -6,237 +6,226 @@ import (
66
"time"
77
)
88

9+
type clientConn struct {
10+
sconn net.Conn
11+
conn net.Conn
12+
istcp bool
13+
remoteAddress net.Addr // udp used
14+
}
15+
16+
func (c *clientConn) Read(b []byte) (n int, err error) {
17+
if n, err = c.conn.Read(b); err != nil || c.istcp {
18+
return
19+
}
20+
d, err := ReadRequestUDP(b[0:n])
21+
if err != nil {
22+
return 0, err
23+
}
24+
n = copy(b, d.Data)
25+
return n, nil
26+
}
27+
28+
func (c *clientConn) Write(b []byte) (n int, err error) {
29+
if c.istcp {
30+
return c.conn.Write(b)
31+
}
32+
33+
a, h, p, err := ParseAddress(c.remoteAddress.String())
34+
if err != nil {
35+
return 0, err
36+
}
37+
38+
d := NewRequestUDP(a, h, p, b)
39+
b1 := d.Bytes()
40+
if n, err = c.conn.Write(b1); err != nil {
41+
return 0, err
42+
}
43+
if len(b1) != n {
44+
return 0, errors.New("not write full")
45+
}
46+
return len(b), nil
47+
}
48+
49+
func (c *clientConn) Close() error {
50+
if !c.istcp {
51+
// 使用udp代理后,需要关闭原有的tcp连接
52+
c.sconn.Close()
53+
}
54+
return c.conn.Close()
55+
}
56+
57+
func (c *clientConn) LocalAddr() net.Addr {
58+
return c.conn.LocalAddr()
59+
}
60+
61+
func (c *clientConn) RemoteAddr() net.Addr {
62+
return c.remoteAddress
63+
}
64+
65+
func (c *clientConn) SetDeadline(t time.Time) error {
66+
return c.conn.SetDeadline(t)
67+
}
68+
69+
func (c *clientConn) SetReadDeadline(t time.Time) error {
70+
return c.conn.SetReadDeadline(t)
71+
}
72+
73+
func (c *clientConn) SetWriteDeadline(t time.Time) error {
74+
return c.conn.SetWriteDeadline(t)
75+
}
76+
977
type Client struct {
1078
Server string
1179
Username string
1280
Password string
13-
DialTCP func(net string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error)
14-
// On cmd UDP, let server control the tcp and udp connection relationship
15-
tcpConn *net.TCPConn
16-
udpConn *net.UDPConn
17-
remoteAddress net.Addr //udp used
81+
DialTCP func(network string, laddr, raddr *net.TCPAddr) (net.Conn, error)
82+
}
83+
84+
func (c *Client) dial(laddr *net.TCPAddr) (conn net.Conn, err error) {
85+
raddr, err := net.ResolveTCPAddr("tcp", c.Server)
86+
if err != nil {
87+
return nil, err
88+
}
89+
if c.DialTCP != nil {
90+
return c.DialTCP("tcp", laddr, raddr)
91+
}
92+
return net.DialTCP("tcp", laddr, raddr)
1893
}
1994

2095
func (c *Client) Dial(network, addr string) (net.Conn, error) {
2196
return c.DialWithLocalAddr(network, "", addr)
2297
}
2398

2499
func (c Client) DialWithLocalAddr(network, src, dst string) (net.Conn, error) {
25-
26100
var err error
27-
101+
28102
var la *net.TCPAddr
29103
if src != "" {
30104
la, err = net.ResolveTCPAddr("tcp", src)
31105
if err != nil {
32106
return nil, err
33107
}
34108
}
35-
36-
if err := c.negotiate(la); err != nil {
109+
conn, err := c.dial(la)
110+
if err != nil {
37111
return nil, err
38112
}
39-
40-
if network == "tcp" {
41-
42-
if c.remoteAddress == nil {
43-
c.remoteAddress, err = net.ResolveTCPAddr("tcp", dst)
44-
if err != nil {
45-
return nil, err
46-
}
47-
}
48-
49-
a, h, p, err := parseAddress(dst)
50-
if err != nil {
51-
return nil, err
52-
}
53-
54-
r := newRequestTCP(CmdConnect, a, h, p)
55-
if _, err := c.request(r); err != nil {
56-
return nil, err
57-
}
58-
59-
return &c, nil
60-
}else if network == "udp" {
61-
62-
if c.remoteAddress == nil {
63-
c.remoteAddress, err = net.ResolveUDPAddr("udp", dst)
64-
if err != nil {
65-
return nil, err
66-
}
67-
}
68-
69-
laddr := &net.UDPAddr{
70-
IP: c.tcpConn.LocalAddr().(*net.TCPAddr).IP,
71-
Port: c.tcpConn.LocalAddr().(*net.TCPAddr).Port,
72-
Zone: c.tcpConn.LocalAddr().(*net.TCPAddr).Zone,
73-
}
74-
75-
a, h, p, err := parseAddress(laddr.String())
76-
if err != nil {
77-
return nil, err
78-
}
79-
80-
//告诉服务器,我发起的UDP本地地址
81-
r := newRequestTCP(CmdUDP, a, h, p)
82-
rp, err := c.request(r)
83-
if err != nil {
84-
return nil, err
85-
}
86-
87-
//服务给的端口地址
88-
raddr, err := net.ResolveUDPAddr("udp", rp.Address())
89-
if err != nil {
113+
114+
nt, ok := conn.(*Negotiate)
115+
if !ok || !nt.done {
116+
if err = negotiate(conn, c.Username, c.Password); err != nil {
90117
return nil, err
91118
}
92-
93-
c.udpConn, err = net.DialUDP("udp", laddr, raddr)
119+
}
120+
raddr, err := net.ResolveTCPAddr("tcp", dst)
121+
if err != nil {
122+
return nil, err
123+
}
124+
switch network {
125+
case "tcp":
126+
c.requestTCP(conn, dst)
127+
return &clientConn{conn: conn, istcp: true, remoteAddress: raddr}, nil
128+
case "udp":
129+
udpConn, err := c.requestUDP(conn, dst)
94130
if err != nil {
131+
conn.Close()
95132
return nil, err
96133
}
97-
return &c, nil
134+
return &clientConn{sconn: conn, conn: udpConn, remoteAddress: raddr}, nil
98135
}
99136
return nil, errors.New("unsupport network")
100137
}
101138

102-
func (c *Client) Read(b []byte) (int, error) {
103-
if c.udpConn == nil {
104-
return c.tcpConn.Read(b)
105-
}
106-
n, err := c.udpConn.Read(b)
107-
if err != nil {
108-
return 0, err
109-
}
110-
d, err := readRequestUDP(b[0:n])
139+
func (c *Client) requestTCP(conn net.Conn, dst string) (*ReplyTCP, error) {
140+
a, h, p, err := ParseAddress(dst)
111141
if err != nil {
112-
return 0, err
142+
return nil, err
113143
}
114-
n = copy(b, d.Data)
115-
return n, nil
144+
145+
r := NewRequestTCP(CmdConnect, a, h, p)
146+
return c.request(conn, r)
116147
}
117148

118-
func (c *Client) Write(b []byte) (int, error) {
119-
if c.udpConn == nil {
120-
return c.tcpConn.Write(b)
121-
}
122-
a, h, p, err := parseAddress(c.remoteAddress.String())
149+
func (c *Client) requestUDP(conn net.Conn, dst string) (net.Conn, error) {
150+
// 告诉代理服务器,我使用这个地址端口向代理服务器发起UDP请求
151+
laddr := &net.UDPAddr{
152+
IP: conn.LocalAddr().(*net.TCPAddr).IP,
153+
Port: conn.LocalAddr().(*net.TCPAddr).Port,
154+
Zone: conn.LocalAddr().(*net.TCPAddr).Zone,
155+
}
156+
a, h, p, err := ParseAddress(laddr.String())
123157
if err != nil {
124-
return 0, err
158+
return nil, err
125159
}
126-
127-
d := newReplyUDP(a, h, p, b)
128-
b1 := d.Bytes()
129-
n, err := c.udpConn.Write(b1)
160+
r := NewRequestTCP(CmdUDP, a, h, p)
161+
rp, err := c.request(conn, r)
130162
if err != nil {
131-
return 0, err
132-
}
133-
if len(b1) != n {
134-
return 0, errors.New("not write full")
163+
return nil, err
135164
}
136-
return len(b), nil
137-
}
138165

139-
func (c *Client) Close() error {
140-
if c.udpConn == nil {
141-
return c.tcpConn.Close()
142-
}
143-
if c.tcpConn != nil {
144-
c.tcpConn.Close()
166+
// 服务给监听的端口地址
167+
raddr, err := net.ResolveUDPAddr("udp", rp.Address())
168+
if err != nil {
169+
return nil, err
145170
}
146-
return c.udpConn.Close()
171+
return net.DialUDP("udp", laddr, raddr)
147172
}
148173

149-
func (c *Client) LocalAddr() net.Addr {
150-
if c.udpConn == nil {
151-
return c.tcpConn.LocalAddr()
174+
func (c *Client) request(conn net.Conn, r *RequestTCP) (*ReplyTCP, error) {
175+
if _, err := r.WriteTo(conn); err != nil {
176+
return nil, err
152177
}
153-
return c.udpConn.LocalAddr()
154-
}
155-
156-
func (c *Client) RemoteAddr() net.Addr {
157-
return c.remoteAddress
158-
}
159-
160-
func (c *Client) SetDeadline(t time.Time) error {
161-
if c.udpConn == nil {
162-
return c.tcpConn.SetDeadline(t)
178+
rp, err := ReadReplyTCP(conn)
179+
if err != nil {
180+
return nil, err
181+
}
182+
if rp.Rep != RepSuccess {
183+
return nil, errors.New("host unreachable")
163184
}
164-
return c.udpConn.SetDeadline(t)
185+
return rp, nil
165186
}
166187

167-
func (c *Client) SetReadDeadline(t time.Time) error {
168-
if c.udpConn == nil {
169-
return c.tcpConn.SetReadDeadline(t)
170-
}
171-
return c.udpConn.SetReadDeadline(t)
188+
type Negotiate struct {
189+
net.Conn
190+
done bool
172191
}
173192

174-
func (c *Client) SetWriteDeadline(t time.Time) error {
175-
if c.udpConn == nil {
176-
return c.tcpConn.SetWriteDeadline(t)
177-
}
178-
return c.udpConn.SetWriteDeadline(t)
193+
// 验证账号
194+
func (T *Negotiate) Auth(username, password string) error {
195+
T.done = true
196+
return negotiate(T, username, password)
179197
}
180198

181-
func (c *Client) negotiate(laddr *net.TCPAddr) error {
182-
raddr, err := net.ResolveTCPAddr("tcp", c.Server)
183-
if err != nil {
184-
return err
185-
}
186-
if c.DialTCP != nil {
187-
c.tcpConn, err = c.DialTCP("tcp", laddr, raddr)
188-
}else{
189-
c.tcpConn, err = net.DialTCP("tcp", laddr, raddr)
190-
}
191-
if err != nil {
192-
return err
193-
}
194-
199+
func negotiate(conn net.Conn, username, passwod string) error {
195200
m := MethodNone
196-
if c.Username != "" {
201+
if username != "" {
197202
m = MethodUsernamePassword
198203
}
199-
rq := newNegotiateWriteRequest([]byte{m})
200-
if _, err := rq.WriteTo(c.tcpConn); err != nil {
204+
rq := NewNegotiateMethodRequest([]byte{m})
205+
if _, err := rq.WriteTo(conn); err != nil {
201206
return err
202207
}
203-
204-
rp, err := negotiateReadReply(c.tcpConn)
208+
rp, err := ReadNegotiateMethodReply(conn)
205209
if err != nil {
206210
return err
207211
}
208212
if rp.Method != m {
209-
return errors.New("Unsupport method")
213+
return errors.New("unsupport method")
210214
}
211-
212215
if m == MethodUsernamePassword {
213-
urq := newNegotiateAuthRequest([]byte(c.Username), []byte(c.Password))
214-
if _, err := urq.WriteTo(c.tcpConn); err != nil {
216+
urq := NewNegotiateAuthRequest([]byte(username), []byte(passwod))
217+
if _, err := urq.WriteTo(conn); err != nil {
215218
return err
216219
}
217-
218-
urp, err := negotiateAuthReply(c.tcpConn)
220+
221+
urp, err := ReadNegotiateAuthReply(conn)
219222
if err != nil {
220223
return err
221224
}
222-
225+
223226
if urp.Status != UserPassStatusSuccess {
224227
return ErrUserPassAuth
225228
}
226229
}
227230
return nil
228231
}
229-
230-
func (c *Client) request(r *RequestTCP) (*ReplyTCP, error) {
231-
if _, err := r.WriteTo(c.tcpConn); err != nil {
232-
return nil, err
233-
}
234-
rp, err := readReplyTCP(c.tcpConn)
235-
if err != nil {
236-
return nil, err
237-
}
238-
if rp.Rep != RepSuccess {
239-
return nil, errors.New("Host unreachable")
240-
}
241-
return rp, nil
242-
}

0 commit comments

Comments
 (0)