-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathserver.go
288 lines (248 loc) · 7.75 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
package drpc
import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"reflect"
"strings"
"sync"
"time"
"github.com/devhg/drpc/codec"
)
// MagicNumber identifies a drpc connection; the Option sent at the start of
// every connection must carry this value.
const MagicNumber = 0x3bef5c

// Unexported constants used by the HTTP transport.
const (
	// connected is appended to "HTTP/1.0 " in the reply to a successful CONNECT.
	connected = "200 Connected to drpc"
	// defaultRPCPath is the HTTP path where the RPC handler is registered.
	defaultRPCPath = "/_drpc_"
	// defaultDebugPath is the HTTP path where the debug handler is registered.
	defaultDebugPath = "/debug/drpc"
)
// Message layout on a drpc connection:
//
//	| Option{MagicNumber: xxx, CodecType: xxx} | Header{ServiceMethod ...} | Body interface{} |
//	| <------   always JSON-encoded   ------>  | <---- encoding chosen by CodecType ----> |
//
// Within one connection the Option comes first and appears exactly once;
// any number of Header/Body pairs may follow, so a stream can look like:
//
//	| Option | Header1 | Body1 | Header2 | Body2 | ...
//
// Option is that leading handshake. The server decodes it with encoding/json
// before it has chosen a codec, which is why it is JSON regardless of CodecType.
type Option struct {
// MagicNumber must equal the package-level MagicNumber constant; otherwise
// the server logs the mismatch and closes the connection.
MagicNumber int
// CodecType selects the codec used for every Header/Body after the Option;
// the server rejects types that have no entry in codec.NewCodecFuncMap.
CodecType codec.Type
ConnectTimeout time.Duration // 0 means no limit
// HandleTimeout bounds how long the server waits for each call before
// answering with a timeout error; 0 means the server waits indefinitely.
HandleTimeout time.Duration
}
// DefaultOption is a ready-made handshake: gob encoding, a 10 second
// ConnectTimeout, and HandleTimeout left at its zero value (no server-side limit).
var DefaultOption = &Option{
	CodecType:      codec.GobType,
	MagicNumber:    MagicNumber,
	ConnectTimeout: 10 * time.Second,
}
// Server represents an RPC Server. Registered services live in serviceMap,
// keyed by service name.
type Server struct {
	serviceMap sync.Map
}

// NewServer returns a Server with no registered services.
func NewServer() *Server {
	srv := new(Server)
	return srv
}

// DefaultServer is the default instance of *Server; the package-level
// Accept, Register and HandleHTTP functions operate on it.
var DefaultServer = NewServer()
// Accept accepts connections on the listener and serves requests
// for each incoming connection, using DefaultServer.
func Accept(lis net.Listener) {
	DefaultServer.Accept(lis)
}

// Accept accepts connections on lis and serves each one on its own goroutine
// via ServeConn. It blocks until lis.Accept returns an error (for example
// because the listener was closed); the error is logged and Accept returns.
func (server *Server) Accept(lis net.Listener) {
	for {
		conn, err := lis.Accept()
		if err != nil {
			// conn is nil here: serving it would panic, and looping on a
			// closed listener would log the same error forever.
			log.Println("rpc server: accept error:", err)
			return
		}
		go server.ServeConn(conn)
	}
}
// Register publishes the receiver's methods in the DefaultServer.
func Register(rcvr interface{}) error {
	return DefaultServer.Register(rcvr)
}

// Register publishes the methods of rcvr under the service name chosen by
// newService. The first registration of a name wins; registering the same
// name again returns an error and leaves the existing service in place.
func (server *Server) Register(rcvr interface{}) error {
	svc := newService(rcvr)
	_, loaded := server.serviceMap.LoadOrStore(svc.name, svc)
	if !loaded {
		return nil
	}
	return errors.New("rpc server: service already defined: " + svc.name)
}
// ServeConn runs the server on a single connection. It reads the JSON-encoded
// Option handshake and validates its magic number and codec type. It then
// serves Header/Body requests with the selected codec via ServeCodec until
// the stream ends. ServeConn blocks until then, and conn is closed before it
// returns.
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
	defer func() { _ = conn.Close() }()

	dec := json.NewDecoder(conn)
	var opt Option
	if err := dec.Decode(&opt); err != nil {
		log.Println("rpc server: decode Option error:", err)
		return
	}
	if opt.MagicNumber != MagicNumber {
		log.Printf("rpc server: invalid magic number %x\n", opt.MagicNumber)
		return
	}
	codecFunc := codec.NewCodecFuncMap[opt.CodecType]
	if codecFunc == nil {
		log.Printf("rpc server: invalid codec type %s\n", opt.CodecType)
		return
	}

	// json.Decoder buffers and may have read past the Option into the first
	// Header/Body. Replay those buffered bytes before reading further from
	// conn so the codec sees the stream starting exactly after the Option.
	rwc := &replayConn{
		Reader: io.MultiReader(dec.Buffered(), conn),
		Writer: conn,
		Closer: conn,
	}
	server.ServeCodec(codecFunc(rwc), opt.HandleTimeout)
}

// replayConn is an io.ReadWriteCloser that reads from Reader (bytes buffered
// during the handshake, then the connection) and writes/closes directly on
// the underlying connection.
type replayConn struct {
	io.Reader
	io.Writer
	io.Closer
}
// invalidRequest is the placeholder body sent with error responses; the
// error text itself travels in the Header's Error field.
var invalidRequest = struct{}{}

// ServeCodec reads requests from cc and dispatches each valid one to its own
// handleRequest goroutine. Requests whose header was read but that otherwise
// failed get an immediate error response. Once a header can no longer be read,
// it waits for all in-flight handlers and closes cc.
func (server *Server) ServeCodec(cc codec.Codec, timeout time.Duration) {
	// Both are allocated up front and shared by pointer with every handler.
	// Declaring them as `var sending *sync.Mutex` / `var wg *sync.WaitGroup`
	// would hand nil pointers to the handlers; a plain zero-value
	// sync.Mutex/WaitGroup would only be safe if used solely within this function.
	var (
		sending = &sync.Mutex{}     // serializes response writes so they don't interleave
		wg      = &sync.WaitGroup{} // tracks handleRequest goroutines still running
	)
	for {
		req, err := server.readRequest(cc)
		switch {
		case err == nil:
			wg.Add(1)
			go server.handleRequest(cc, req, sending, wg, timeout)
		case req == nil:
			// The header itself could not be read: stop serving this codec.
			wg.Wait()
			_ = cc.Close()
			return
		default:
			req.h.Error = err.Error()
			server.sendResponse(cc, req.h, invalidRequest, sending)
		}
	}
}
// request bundles one decoded call: its header, the argument and reply
// values, and the service/method it resolved to.
type request struct {
h *codec.Header // header of request
// argv holds the call argument, decoded from the request body in readRequest.
argv reflect.Value
// replyv is passed to the method and, on success, sent back as the response body.
replyv reflect.Value
// mTyp is the method resolved from h.ServiceMethod.
mTyp *methodType
// servci is the service that owns mTyp (the field name is a misspelling of
// "service"; NOTE(review): renaming it requires updating every user).
servci *service
}
// readRequest reads one request (a header followed by its body) from cc.
//
// It returns (nil, err) when the header cannot be read; the caller should
// stop serving. It returns (req, err) with req.h set when the header was read
// but the service/method is unknown or the body could not be decoded, so the
// caller can still reply with an error. On success req.argv holds the
// decoded argument and req.replyv a fresh reply value.
func (server *Server) readRequest(cc codec.Codec) (*request, error) {
	header, err := server.readRequestHeader(cc)
	if err != nil {
		return nil, err
	}
	req := &request{h: header}
	req.servci, req.mTyp, err = server.findService(header.ServiceMethod)
	if err != nil {
		// The client sent a body for this call too. Consume and discard it so
		// the next ReadHeader starts at the next request instead of parsing
		// this body as a header. NOTE(review): assumes the codec treats a nil
		// target as "discard", as encoding/gob's Decoder.Decode does; confirm
		// for other codecs. A failure here means the stream is already broken
		// and the next header read will report it.
		_ = cc.ReadBody(nil)
		return req, err
	}
	req.argv = req.mTyp.newArgv()
	req.replyv = req.mTyp.newRetv()

	// ReadBody needs a pointer; take the address when argv is not one already.
	argvInter := req.argv.Interface()
	if req.argv.Kind() != reflect.Ptr {
		argvInter = req.argv.Addr().Interface()
	}
	if err = cc.ReadBody(argvInter); err != nil {
		log.Println("rpc server: read body error:", err)
		return req, err
	}
	return req, nil
}
// readRequestHeader decodes the next request header from cc. On failure it
// returns a nil header and the error. io.EOF and io.ErrUnexpectedEOF are
// returned silently; any other error is also logged.
func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
	h := new(codec.Header)
	err := cc.ReadHeader(h)
	if err == nil {
		return h, nil
	}
	if err != io.EOF && err != io.ErrUnexpectedEOF {
		log.Println("rpc server: read header error:", err)
	}
	return nil, err
}
// findService resolves a "Service.Method" string to the registered service
// and its method. The string is split at the last dot, so service names may
// themselves contain dots. When only the method is missing, the found
// service is still returned alongside the error.
func (server *Server) findService(serviceMethod string) (servci *service, mTyp *methodType, err error) {
	dot := strings.LastIndex(serviceMethod, ".")
	if dot == -1 {
		return nil, nil, errors.New("rpc server: service/method request illegal-formed: " + serviceMethod)
	}
	serviceName := serviceMethod[:dot]
	methodName := serviceMethod[dot+1:]

	loaded, found := server.serviceMap.Load(serviceName)
	if !found {
		return nil, nil, errors.New("rpc server: can't find service: " + serviceName)
	}
	svc := loaded.(*service)

	method := svc.method[methodName]
	if method == nil {
		return svc, nil, errors.New("rpc server: can't find method: " + methodName)
	}
	return svc, method, nil
}
// handleRequest invokes the method for req and guarantees exactly one
// response is written for it.
//
// The call runs in its own goroutine. With timeout == 0 the handler waits
// for the call to finish and for its response to be sent. Otherwise it
// waits at most timeout:
//   - If the call finishes first, the call goroutine sends the reply (or the
//     method's error).
//   - If the timer fires first, the handler sends a timeout error instead.
//     Once the method returns, the call goroutine exits without writing
//     anything.
//
// wg.Done runs when the handler returns. After a timeout that can happen
// before the method returns, but the call goroutine no longer touches cc or
// req.h at that point.
func (server *Server) handleRequest(cc codec.Codec, req *request,
	sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) {
	defer wg.Done()

	called := make(chan struct{})   // call goroutine -> handler: method returned in time
	sent := make(chan struct{})     // call goroutine -> handler: response written
	timedOut := make(chan struct{}) // handler -> call goroutine: give up, do not respond

	go func() {
		err := req.servci.call(req.mTyp, req.argv, req.replyv)
		select {
		case called <- struct{}{}:
		case <-timedOut:
			// The handler already replied with a timeout error. Responding
			// again would duplicate it; blocking on called would leak us.
			return
		}
		if err != nil {
			req.h.Error = err.Error()
			server.sendResponse(cc, req.h, invalidRequest, sending)
		} else {
			server.sendResponse(cc, req.h, req.replyv.Interface(), sending)
		}
		sent <- struct{}{}
	}()

	if timeout == 0 {
		// No limit. Return after the response is sent: falling through to
		// the select below would fire time.After(0) immediately and send a
		// second, spurious timeout response.
		<-called
		<-sent
		return
	}

	select {
	case <-time.After(timeout):
		close(timedOut)
		req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout)
		server.sendResponse(cc, req.h, invalidRequest, sending)
	case <-called:
		<-sent
	}
}
// sendResponse writes one Header/Body pair to cc. Holding the shared sending
// mutex for the whole write keeps concurrent handlers from interleaving their
// responses. A write error is logged, not returned.
func (server *Server) sendResponse(cc codec.Codec, h *codec.Header,
	body interface{}, sending *sync.Mutex) {
	sending.Lock()
	defer sending.Unlock()

	err := cc.Write(h, body)
	if err == nil {
		return
	}
	log.Println("rpc server: write response error:", err)
}
// ServeHTTP lets a Server run over HTTP. Only CONNECT requests are accepted;
// any other method gets 405. The handler takes over (hijacks) the underlying
// connection, answers with an "HTTP/1.0 200 Connected to drpc" status line,
// and from then on speaks the raw drpc protocol on it via ServeConn.
func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if req.Method != http.MethodConnect {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(http.StatusMethodNotAllowed)
		_, _ = io.WriteString(w, "405 must CONNECT\n")
		return
	}
	// Not every ResponseWriter can hand over its connection (HTTP/2 ones
	// cannot); an unchecked type assertion would panic the handler.
	hijacker, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "500 connection hijacking not supported", http.StatusInternalServerError)
		return
	}
	conn, _, err := hijacker.Hijack()
	if err != nil {
		log.Println("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
		return
	}
	// Acknowledge the CONNECT on the hijacked connection before switching
	// to the drpc protocol.
	_, _ = io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
	server.ServeConn(conn)
}
// HandleHTTP registers this server on http.DefaultServeMux: the server itself
// handles RPC CONNECTs at defaultRPCPath, and a debugHTTP wrapper around it is
// mounted at defaultDebugPath.
// It is still necessary to invoke http.Serve(), typically in a go statement.
func (server *Server) HandleHTTP() {
	mux := http.DefaultServeMux
	mux.Handle(defaultRPCPath, server)
	mux.Handle(defaultDebugPath, debugHTTP{server})
	log.Println("rpc server: debug path:", defaultDebugPath)
}

// HandleHTTP registers DefaultServer's HTTP handlers; it is a shorthand for
// DefaultServer.HandleHTTP().
func HandleHTTP() {
	DefaultServer.HandleHTTP()
}