Skip to content

Commit

Permalink
Fix pkg#195 Handler wrappers deal with errors
Browse files Browse the repository at this point in the history
TLDR; have the Handler wrappers generate the response packets for the
errors instead of returning the errors.

Errors from the handers was returned up the stack to the request-server
handler call. In cases of errors the packet for the error was then
generated using the passed in packet (the incoming packet).

For file read/write/list operations the incoming packet data (includeing
ID) is put in a queue/channel so that it can be handled by the same file
handling code returned by the handler.

Due to gorouting scheduling, in some cases the packets in the channel
won't line up with the incoming packets. That is the worker that took an
incoming packet with one ID might respond with a packet with a different
ID. This doesn't matter as the reply order of packets doesn't matter.

BUT when this response packet that doesn't match IDs with the incoming
packet system returns an error, it uses the incoming packets ID to
generate the error packet. So the error response would use that ID and
the packet from the queue, once processed, would also use that ID
(because on success  it uses the ID from the packet in the queue).

This patch fixes the issue by having the code that works with the
packets, either incoming or from the queue, generate the error packet.
So there is no chance of them getting out of sync.
  • Loading branch information
eikenb committed Aug 20, 2017
1 parent bc6b56a commit cffd2fa
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 72 deletions.
17 changes: 4 additions & 13 deletions request-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (rs *RequestServer) packetWorker(pktChan chan requestPacket) error {
} else {
request = requestFromPacket(
&sshFxpStatPacket{ID: pkt.id(), Path: request.Filepath})
rpkt = rs.handle(request, pkt)
rpkt = request.handle(rs.Handlers)
}
case *sshFxpFsetstatPacket:
handle := pkt.getHandle()
Expand All @@ -154,7 +154,7 @@ func (rs *RequestServer) packetWorker(pktChan chan requestPacket) error {
&sshFxpSetstatPacket{ID: pkt.id(), Path: request.Filepath,
Flags: pkt.Flags, Attrs: pkt.Attrs,
})
rpkt = rs.handle(request, pkt)
rpkt = request.handle(rs.Handlers)
}
case hasHandle:
handle := pkt.getHandle()
Expand All @@ -163,11 +163,11 @@ func (rs *RequestServer) packetWorker(pktChan chan requestPacket) error {
if !ok {
rpkt = statusFromError(pkt, syscall.EBADF)
} else {
rpkt = rs.handle(request, pkt)
rpkt = request.handle(rs.Handlers)
}
case hasPath:
request := requestFromPacket(pkt)
rpkt = rs.handle(request, pkt)
rpkt = request.handle(rs.Handlers)
default:
return errors.Errorf("unexpected packet type %T", pkt)
}
Expand Down Expand Up @@ -200,15 +200,6 @@ func cleanPath(path string) string {
return cleanSlashPath
}

func (rs *RequestServer) handle(request Request, pkt requestPacket) responsePacket {
// fmt.Println("Request Method: ", request.Method)
rpkt, err := request.handle(rs.Handlers)
if err != nil {
rpkt = statusFromError(pkt, err)
}
return rpkt
}

// Wrap underlying connection methods to use packetManager
func (rs *RequestServer) sendPacket(m encoding.BinaryMarshaler) error {
if pkt, ok := m.(responsePacket); ok {
Expand Down
99 changes: 57 additions & 42 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,16 @@ type state struct {
}

type packet_data struct {
id uint32
_id uint32
data []byte
length uint32
offset int64
}

func (pd packet_data) id() uint32 {
return pd._id
}

// New Request initialized based on packet data
func requestFromPacket(pkt hasPath) Request {
method := requestMethod(pkt)
Expand All @@ -72,6 +76,10 @@ func NewRequest(method, path string) Request {
return request
}

func (r Request) id() uint32 {
return r.pkt_id
}

// Returns current offset for file list, and sets next offset
func (r Request) lsNext(offset int64) (current int64) {
r.stateLock.RLock()
Expand Down Expand Up @@ -152,7 +160,7 @@ func (r *Request) popPacket() packet_data {
}

// called from worker to handle packet/request
func (r *Request) handle(handlers Handlers) (responsePacket, error) {
func (r *Request) handle(handlers Handlers) responsePacket {
switch r.Method {
case "Get":
return fileget(handlers.FileGet, r)
Expand All @@ -163,83 +171,82 @@ func (r *Request) handle(handlers Handlers) (responsePacket, error) {
case "List", "Stat", "Readlink":
return filelist(handlers.FileList, r)
default:
return nil, errors.Errorf("unexpected method: %s", r.Method)
return statusFromError(r,
errors.Errorf("unexpected method: %s", r.Method))
}
}

// wrap FileReader handler
func fileget(h FileReader, r *Request) (responsePacket, error) {
func fileget(h FileReader, r *Request) responsePacket {
var err error
reader := r.getReader()
pd := r.popPacket()
if reader == nil {
reader, err = h.Fileread(r)
if err != nil {
return nil, err
return statusFromError(pd, err)
}
r.setFileState(reader)
}

pd := r.popPacket()
data := make([]byte, clamp(pd.length, maxTxPacket))
n, err := reader.ReadAt(data, pd.offset)
// only return EOF erro if no data left to read
if err != nil && (err != io.EOF || n == 0) {
return nil, err
return statusFromError(pd, err)
}
return &sshFxpDataPacket{
ID: pd.id,
ID: pd.id(),
Length: uint32(n),
Data: data[:n],
}, nil
}
}

// wrap FileWriter handler
func fileput(h FileWriter, r *Request) (responsePacket, error) {
func fileput(h FileWriter, r *Request) responsePacket {
var err error
writer := r.getWriter()
if writer == nil {
writer, err = h.Filewrite(r)
if err != nil {
return nil, err
return statusFromError(r, err)
}
r.setFileState(writer)
}

pd := r.popPacket()
_, err = writer.WriteAt(pd.data, pd.offset)
if err != nil {
return nil, err
return statusFromError(pd, err)
}
return &sshFxpStatusPacket{
ID: pd.id,
ID: pd.id(),
StatusError: StatusError{
Code: ssh_FX_OK,
},
}, nil
}}
}

// wrap FileCmder handler
func filecmd(h FileCmder, r *Request) (responsePacket, error) {
func filecmd(h FileCmder, r *Request) responsePacket {
err := h.Filecmd(r)
if err != nil {
return nil, err
return statusFromError(r, err)
}
return &sshFxpStatusPacket{
ID: r.pkt_id,
StatusError: StatusError{
Code: ssh_FX_OK,
},
}, nil
}}
}

// wrap FileLister handler
func filelist(h FileLister, r *Request) (responsePacket, error) {
func filelist(h FileLister, r *Request) responsePacket {
var err error
lister := r.getLister()
if lister == nil {
lister, err = h.Filelist(r)
if err != nil {
return nil, err
return statusFromError(r, err)
}
r.setFileState(lister)
}
Expand All @@ -248,28 +255,19 @@ func filelist(h FileLister, r *Request) (responsePacket, error) {
finfo := make([]os.FileInfo, MaxFilelist)
n, err := lister.ListAt(finfo, offset)
// ignore EOF as we only return it when there are no results
if err != nil && err != io.EOF {
return nil, err
}
finfo = finfo[:n] // avoid need for nil tests below

// no results
if n == 0 {
switch r.Method {
case "List":
return nil, io.EOF
case "Stat", "Readlink":
err = &os.PathError{Op: "readlink", Path: r.Filepath,
Err: syscall.ENOENT}
return nil, err
}
}

switch r.Method {
case "List":
pd := r.popPacket()
if err != nil && err != io.EOF {
return statusFromError(pd, err)
}
if n == 0 {
return statusFromError(pd, io.EOF)
}
dirname := filepath.ToSlash(path.Base(r.Filepath))
ret := &sshFxpNamePacket{ID: pd.id}
ret := &sshFxpNamePacket{ID: pd.id()}

for _, fi := range finfo {
ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{
Expand All @@ -278,13 +276,29 @@ func filelist(h FileLister, r *Request) (responsePacket, error) {
Attrs: []interface{}{fi},
})
}
return ret, nil
return ret
case "Stat":
if err != nil && err != io.EOF {
return statusFromError(r, err)
}
if n == 0 {
err = &os.PathError{Op: "stat", Path: r.Filepath,
Err: syscall.ENOENT}
return statusFromError(r, err)
}
return &sshFxpStatResponse{
ID: r.pkt_id,
info: finfo[0],
}, nil
}
case "Readlink":
if err != nil && err != io.EOF {
return statusFromError(r, err)
}
if n == 0 {
err = &os.PathError{Op: "readlink", Path: r.Filepath,
Err: syscall.ENOENT}
return statusFromError(r, err)
}
filename := finfo[0].Name()
return &sshFxpNamePacket{
ID: r.pkt_id,
Expand All @@ -293,15 +307,16 @@ func filelist(h FileLister, r *Request) (responsePacket, error) {
LongName: filename,
Attrs: emptyFileStat,
}},
}, nil
}
default:
return nil, errors.Errorf("unexpected method: %s", r.Method)
err = errors.Errorf("unexpected method: %s", r.Method)
return statusFromError(r, err)
}
}

// file data for additional read/write packets
func (r *Request) update(p hasHandle) error {
pd := packet_data{id: p.id()}
pd := packet_data{_id: p.id()}
switch p := p.(type) {
case *sshFxpReadPacket:
r.Method = "Get"
Expand Down
27 changes: 10 additions & 17 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ func testRequest(method string) Request {
stateLock: &sync.RWMutex{},
}
for _, p := range []packet_data{
{id: 1, data: filecontents[:5], length: 5},
{id: 2, data: filecontents[5:], length: 5, offset: 5}} {
{_id: 1, data: filecontents[:5], length: 5},
{_id: 2, data: filecontents[5:], length: 5, offset: 5}} {
request.packets <- p
}
return request
Expand Down Expand Up @@ -122,8 +122,7 @@ func TestRequestGet(t *testing.T) {
request := testRequest("Get")
// req.length is 5, so we test reads in 5 byte chunks
for i, txt := range []string{"file-", "data."} {
pkt, err := request.handle(handlers)
assert.Nil(t, err)
pkt := request.handle(handlers)
dpkt := pkt.(*sshFxpDataPacket)
assert.Equal(t, dpkt.id(), uint32(i+1))
assert.Equal(t, string(dpkt.Data), txt)
Expand All @@ -133,35 +132,30 @@ func TestRequestGet(t *testing.T) {
func TestRequestPut(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Put")
pkt, err := request.handle(handlers)
assert.Nil(t, err)
pkt := request.handle(handlers)
statusOk(t, pkt)
pkt, err = request.handle(handlers)
assert.Nil(t, err)
pkt = request.handle(handlers)
statusOk(t, pkt)
assert.Equal(t, "file-data.", handlers.getOutString())
}

func TestRequestCmdr(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Mkdir")
pkt, err := request.handle(handlers)
assert.Nil(t, err)
pkt := request.handle(handlers)
statusOk(t, pkt)

handlers.returnError()
pkt, err = request.handle(handlers)
assert.Nil(t, pkt)
assert.Equal(t, err, errTest)
pkt = request.handle(handlers)
assert.Equal(t, pkt, statusFromError(pkt, errTest))
}

func TestRequestInfoList(t *testing.T) { testInfoMethod(t, "List") }
func TestRequestInfoReadlink(t *testing.T) { testInfoMethod(t, "Readlink") }
func TestRequestInfoStat(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Stat")
pkt, err := request.handle(handlers)
assert.Nil(t, err)
pkt := request.handle(handlers)
spkt, ok := pkt.(*sshFxpStatResponse)
assert.True(t, ok)
assert.Equal(t, spkt.info.Name(), "request_test.go")
Expand All @@ -170,8 +164,7 @@ func TestRequestInfoStat(t *testing.T) {
func testInfoMethod(t *testing.T, method string) {
handlers := newTestHandlers()
request := testRequest(method)
pkt, err := request.handle(handlers)
assert.Nil(t, err)
pkt := request.handle(handlers)
npkt, ok := pkt.(*sshFxpNamePacket)
assert.True(t, ok)
assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0])
Expand Down

0 comments on commit cffd2fa

Please sign in to comment.