@@ -6,237 +6,226 @@ import (
6
6
"time"
7
7
)
8
8
9
+ type clientConn struct {
10
+ sconn net.Conn
11
+ conn net.Conn
12
+ istcp bool
13
+ remoteAddress net.Addr // udp used
14
+ }
15
+
16
+ func (c * clientConn ) Read (b []byte ) (n int , err error ) {
17
+ if n , err = c .conn .Read (b ); err != nil || c .istcp {
18
+ return
19
+ }
20
+ d , err := ReadRequestUDP (b [0 :n ])
21
+ if err != nil {
22
+ return 0 , err
23
+ }
24
+ n = copy (b , d .Data )
25
+ return n , nil
26
+ }
27
+
28
+ func (c * clientConn ) Write (b []byte ) (n int , err error ) {
29
+ if c .istcp {
30
+ return c .conn .Write (b )
31
+ }
32
+
33
+ a , h , p , err := ParseAddress (c .remoteAddress .String ())
34
+ if err != nil {
35
+ return 0 , err
36
+ }
37
+
38
+ d := NewRequestUDP (a , h , p , b )
39
+ b1 := d .Bytes ()
40
+ if n , err = c .conn .Write (b1 ); err != nil {
41
+ return 0 , err
42
+ }
43
+ if len (b1 ) != n {
44
+ return 0 , errors .New ("not write full" )
45
+ }
46
+ return len (b ), nil
47
+ }
48
+
49
+ func (c * clientConn ) Close () error {
50
+ if ! c .istcp {
51
+ // 使用udp代理后,需要关闭原有的tcp连接
52
+ c .sconn .Close ()
53
+ }
54
+ return c .conn .Close ()
55
+ }
56
+
57
+ func (c * clientConn ) LocalAddr () net.Addr {
58
+ return c .conn .LocalAddr ()
59
+ }
60
+
61
+ func (c * clientConn ) RemoteAddr () net.Addr {
62
+ return c .remoteAddress
63
+ }
64
+
65
+ func (c * clientConn ) SetDeadline (t time.Time ) error {
66
+ return c .conn .SetDeadline (t )
67
+ }
68
+
69
+ func (c * clientConn ) SetReadDeadline (t time.Time ) error {
70
+ return c .conn .SetReadDeadline (t )
71
+ }
72
+
73
+ func (c * clientConn ) SetWriteDeadline (t time.Time ) error {
74
+ return c .conn .SetWriteDeadline (t )
75
+ }
76
+
9
77
type Client struct {
10
78
Server string
11
79
Username string
12
80
Password string
13
- DialTCP func (net string , laddr , raddr * net.TCPAddr ) (* net.TCPConn , error )
14
- // On cmd UDP, let server control the tcp and udp connection relationship
15
- tcpConn * net.TCPConn
16
- udpConn * net.UDPConn
17
- remoteAddress net.Addr //udp used
81
+ DialTCP func (network string , laddr , raddr * net.TCPAddr ) (net.Conn , error )
82
+ }
83
+
84
+ func (c * Client ) dial (laddr * net.TCPAddr ) (conn net.Conn , err error ) {
85
+ raddr , err := net .ResolveTCPAddr ("tcp" , c .Server )
86
+ if err != nil {
87
+ return nil , err
88
+ }
89
+ if c .DialTCP != nil {
90
+ return c .DialTCP ("tcp" , laddr , raddr )
91
+ }
92
+ return net .DialTCP ("tcp" , laddr , raddr )
18
93
}
19
94
20
95
func (c * Client ) Dial (network , addr string ) (net.Conn , error ) {
21
96
return c .DialWithLocalAddr (network , "" , addr )
22
97
}
23
98
24
99
func (c Client ) DialWithLocalAddr (network , src , dst string ) (net.Conn , error ) {
25
-
26
100
var err error
27
-
101
+
28
102
var la * net.TCPAddr
29
103
if src != "" {
30
104
la , err = net .ResolveTCPAddr ("tcp" , src )
31
105
if err != nil {
32
106
return nil , err
33
107
}
34
108
}
35
-
36
- if err := c . negotiate ( la ); err != nil {
109
+ conn , err := c . dial ( la )
110
+ if err != nil {
37
111
return nil , err
38
112
}
39
-
40
- if network == "tcp" {
41
-
42
- if c .remoteAddress == nil {
43
- c .remoteAddress , err = net .ResolveTCPAddr ("tcp" , dst )
44
- if err != nil {
45
- return nil , err
46
- }
47
- }
48
-
49
- a , h , p , err := parseAddress (dst )
50
- if err != nil {
51
- return nil , err
52
- }
53
-
54
- r := newRequestTCP (CmdConnect , a , h , p )
55
- if _ , err := c .request (r ); err != nil {
56
- return nil , err
57
- }
58
-
59
- return & c , nil
60
- }else if network == "udp" {
61
-
62
- if c .remoteAddress == nil {
63
- c .remoteAddress , err = net .ResolveUDPAddr ("udp" , dst )
64
- if err != nil {
65
- return nil , err
66
- }
67
- }
68
-
69
- laddr := & net.UDPAddr {
70
- IP : c .tcpConn .LocalAddr ().(* net.TCPAddr ).IP ,
71
- Port : c .tcpConn .LocalAddr ().(* net.TCPAddr ).Port ,
72
- Zone : c .tcpConn .LocalAddr ().(* net.TCPAddr ).Zone ,
73
- }
74
-
75
- a , h , p , err := parseAddress (laddr .String ())
76
- if err != nil {
77
- return nil , err
78
- }
79
-
80
- //告诉服务器,我发起的UDP本地地址
81
- r := newRequestTCP (CmdUDP , a , h , p )
82
- rp , err := c .request (r )
83
- if err != nil {
84
- return nil , err
85
- }
86
-
87
- //服务给的端口地址
88
- raddr , err := net .ResolveUDPAddr ("udp" , rp .Address ())
89
- if err != nil {
113
+
114
+ nt , ok := conn .(* Negotiate )
115
+ if ! ok || ! nt .done {
116
+ if err = negotiate (conn , c .Username , c .Password ); err != nil {
90
117
return nil , err
91
118
}
92
-
93
- c .udpConn , err = net .DialUDP ("udp" , laddr , raddr )
119
+ }
120
+ raddr , err := net .ResolveTCPAddr ("tcp" , dst )
121
+ if err != nil {
122
+ return nil , err
123
+ }
124
+ switch network {
125
+ case "tcp" :
126
+ c .requestTCP (conn , dst )
127
+ return & clientConn {conn : conn , istcp : true , remoteAddress : raddr }, nil
128
+ case "udp" :
129
+ udpConn , err := c .requestUDP (conn , dst )
94
130
if err != nil {
131
+ conn .Close ()
95
132
return nil , err
96
133
}
97
- return & c , nil
134
+ return & clientConn { sconn : conn , conn : udpConn , remoteAddress : raddr } , nil
98
135
}
99
136
return nil , errors .New ("unsupport network" )
100
137
}
101
138
102
- func (c * Client ) Read (b []byte ) (int , error ) {
103
- if c .udpConn == nil {
104
- return c .tcpConn .Read (b )
105
- }
106
- n , err := c .udpConn .Read (b )
107
- if err != nil {
108
- return 0 , err
109
- }
110
- d , err := readRequestUDP (b [0 :n ])
139
+ func (c * Client ) requestTCP (conn net.Conn , dst string ) (* ReplyTCP , error ) {
140
+ a , h , p , err := ParseAddress (dst )
111
141
if err != nil {
112
- return 0 , err
142
+ return nil , err
113
143
}
114
- n = copy (b , d .Data )
115
- return n , nil
144
+
145
+ r := NewRequestTCP (CmdConnect , a , h , p )
146
+ return c .request (conn , r )
116
147
}
117
148
118
- func (c * Client ) Write (b []byte ) (int , error ) {
119
- if c .udpConn == nil {
120
- return c .tcpConn .Write (b )
121
- }
122
- a , h , p , err := parseAddress (c .remoteAddress .String ())
149
+ func (c * Client ) requestUDP (conn net.Conn , dst string ) (net.Conn , error ) {
150
+ // 告诉代理服务器,我使用这个地址端口向代理服务器发起UDP请求
151
+ laddr := & net.UDPAddr {
152
+ IP : conn .LocalAddr ().(* net.TCPAddr ).IP ,
153
+ Port : conn .LocalAddr ().(* net.TCPAddr ).Port ,
154
+ Zone : conn .LocalAddr ().(* net.TCPAddr ).Zone ,
155
+ }
156
+ a , h , p , err := ParseAddress (laddr .String ())
123
157
if err != nil {
124
- return 0 , err
158
+ return nil , err
125
159
}
126
-
127
- d := newReplyUDP (a , h , p , b )
128
- b1 := d .Bytes ()
129
- n , err := c .udpConn .Write (b1 )
160
+ r := NewRequestTCP (CmdUDP , a , h , p )
161
+ rp , err := c .request (conn , r )
130
162
if err != nil {
131
- return 0 , err
132
- }
133
- if len (b1 ) != n {
134
- return 0 , errors .New ("not write full" )
163
+ return nil , err
135
164
}
136
- return len (b ), nil
137
- }
138
165
139
- func (c * Client ) Close () error {
140
- if c .udpConn == nil {
141
- return c .tcpConn .Close ()
142
- }
143
- if c .tcpConn != nil {
144
- c .tcpConn .Close ()
166
+ // 服务给监听的端口地址
167
+ raddr , err := net .ResolveUDPAddr ("udp" , rp .Address ())
168
+ if err != nil {
169
+ return nil , err
145
170
}
146
- return c . udpConn . Close ( )
171
+ return net . DialUDP ( "udp" , laddr , raddr )
147
172
}
148
173
149
- func (c * Client ) LocalAddr () net.Addr {
150
- if c . udpConn = = nil {
151
- return c . tcpConn . LocalAddr ()
174
+ func (c * Client ) request ( conn net.Conn , r * RequestTCP ) ( * ReplyTCP , error ) {
175
+ if _ , err := r . WriteTo ( conn ); err ! = nil {
176
+ return nil , err
152
177
}
153
- return c .udpConn .LocalAddr ()
154
- }
155
-
156
- func (c * Client ) RemoteAddr () net.Addr {
157
- return c .remoteAddress
158
- }
159
-
160
- func (c * Client ) SetDeadline (t time.Time ) error {
161
- if c .udpConn == nil {
162
- return c .tcpConn .SetDeadline (t )
178
+ rp , err := ReadReplyTCP (conn )
179
+ if err != nil {
180
+ return nil , err
181
+ }
182
+ if rp .Rep != RepSuccess {
183
+ return nil , errors .New ("host unreachable" )
163
184
}
164
- return c . udpConn . SetDeadline ( t )
185
+ return rp , nil
165
186
}
166
187
167
- func (c * Client ) SetReadDeadline (t time.Time ) error {
168
- if c .udpConn == nil {
169
- return c .tcpConn .SetReadDeadline (t )
170
- }
171
- return c .udpConn .SetReadDeadline (t )
188
+ type Negotiate struct {
189
+ net.Conn
190
+ done bool
172
191
}
173
192
174
- func (c * Client ) SetWriteDeadline (t time.Time ) error {
175
- if c .udpConn == nil {
176
- return c .tcpConn .SetWriteDeadline (t )
177
- }
178
- return c .udpConn .SetWriteDeadline (t )
193
+ // 验证账号
194
+ func (T * Negotiate ) Auth (username , password string ) error {
195
+ T .done = true
196
+ return negotiate (T , username , password )
179
197
}
180
198
181
- func (c * Client ) negotiate (laddr * net.TCPAddr ) error {
182
- raddr , err := net .ResolveTCPAddr ("tcp" , c .Server )
183
- if err != nil {
184
- return err
185
- }
186
- if c .DialTCP != nil {
187
- c .tcpConn , err = c .DialTCP ("tcp" , laddr , raddr )
188
- }else {
189
- c .tcpConn , err = net .DialTCP ("tcp" , laddr , raddr )
190
- }
191
- if err != nil {
192
- return err
193
- }
194
-
199
+ func negotiate (conn net.Conn , username , passwod string ) error {
195
200
m := MethodNone
196
- if c . Username != "" {
201
+ if username != "" {
197
202
m = MethodUsernamePassword
198
203
}
199
- rq := newNegotiateWriteRequest ([]byte {m })
200
- if _ , err := rq .WriteTo (c . tcpConn ); err != nil {
204
+ rq := NewNegotiateMethodRequest ([]byte {m })
205
+ if _ , err := rq .WriteTo (conn ); err != nil {
201
206
return err
202
207
}
203
-
204
- rp , err := negotiateReadReply (c .tcpConn )
208
+ rp , err := ReadNegotiateMethodReply (conn )
205
209
if err != nil {
206
210
return err
207
211
}
208
212
if rp .Method != m {
209
- return errors .New ("Unsupport method" )
213
+ return errors .New ("unsupport method" )
210
214
}
211
-
212
215
if m == MethodUsernamePassword {
213
- urq := newNegotiateAuthRequest ([]byte (c . Username ), []byte (c . Password ))
214
- if _ , err := urq .WriteTo (c . tcpConn ); err != nil {
216
+ urq := NewNegotiateAuthRequest ([]byte (username ), []byte (passwod ))
217
+ if _ , err := urq .WriteTo (conn ); err != nil {
215
218
return err
216
219
}
217
-
218
- urp , err := negotiateAuthReply ( c . tcpConn )
220
+
221
+ urp , err := ReadNegotiateAuthReply ( conn )
219
222
if err != nil {
220
223
return err
221
224
}
222
-
225
+
223
226
if urp .Status != UserPassStatusSuccess {
224
227
return ErrUserPassAuth
225
228
}
226
229
}
227
230
return nil
228
231
}
229
-
230
- func (c * Client ) request (r * RequestTCP ) (* ReplyTCP , error ) {
231
- if _ , err := r .WriteTo (c .tcpConn ); err != nil {
232
- return nil , err
233
- }
234
- rp , err := readReplyTCP (c .tcpConn )
235
- if err != nil {
236
- return nil , err
237
- }
238
- if rp .Rep != RepSuccess {
239
- return nil , errors .New ("Host unreachable" )
240
- }
241
- return rp , nil
242
- }
0 commit comments