Skip to content

Commit

Permalink
agent forwarding support (gliderlabs#31)
Browse files Browse the repository at this point in the history
* agent: added agent forwarding support with an example
* context: encode session id to hex string
* agent: ensure conn doesn't change in closure as loop iterates
* tests: use HostKeyCallback in ClientConfig
* README: noting examples in _example
* agent: documented exported names, added constants for temp file creation

Signed-off-by: Jeff Lindsay <[email protected]>
  • Loading branch information
progrium authored Apr 14, 2017
1 parent 9b56478 commit 1051a0d
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 29 deletions.
13 changes: 8 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,28 @@ building SSH servers. The goal of the API was to make it as simple as using

```
package main
import (
"github.com/gliderlabs/ssh"
"io"
"log"
)
func main() {
ssh.Handle(func(s ssh.Session) {
io.WriteString(s, "Hello world\n")
})
log.Fatal(ssh.ListenAndServe(":2222", nil))
}
```

This package was built after working on nearly a dozen projects using SSH and
collaborating with [@shazow](https://twitter.com/shazow) (known for [ssh-chat](https://github.com/shazow/ssh-chat)).
This package was built after working on nearly a dozen projects at Glider Labs using SSH and collaborating with [@shazow](https://twitter.com/shazow) (known for [ssh-chat](https://github.com/shazow/ssh-chat)).

## Examples

A bunch of great examples are in the `_example` directory.

## Usage

Expand Down
35 changes: 35 additions & 0 deletions _example/ssh-forwardagent/forwardagent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package main

import (
"fmt"
"log"
"os/exec"

"github.com/gliderlabs/ssh"
)

func main() {
ssh.Handle(func(s ssh.Session) {
cmd := exec.Command("ssh-add", "-l")
if ssh.AgentRequested(s) {
l, err := ssh.NewAgentListener()
if err != nil {
log.Fatal(err)
}
defer l.Close()
go ssh.ForwardAgentConnections(l, s)
cmd.Env = append(s.Environ(), fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String()))
} else {
cmd.Env = s.Environ()
}
cmd.Stdout = s
cmd.Stderr = s.Stderr()
if err := cmd.Run(); err != nil {
log.Println(err)
return
}
})

log.Println("starting ssh server on port 2222...")
log.Fatal(ssh.ListenAndServe(":2222", nil))
}
81 changes: 81 additions & 0 deletions agent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package ssh

import (
"io"
"io/ioutil"
"net"
"path"
"sync"

gossh "golang.org/x/crypto/ssh"
)

const (
agentRequestType = "[email protected]"
agentChannelType = "[email protected]"

agentTempDir = "auth-agent"
agentListenFile = "listener.sock"
)

// contextKeyAgentRequest is an internal context key for storing if the
// client requested agent forwarding
var contextKeyAgentRequest = &contextKey{"auth-agent-req"}

func setAgentRequested(sess *session) {
sess.ctx.SetValue(contextKeyAgentRequest, true)
}

// AgentRequested returns true if the client requested agent forwarding.
func AgentRequested(sess Session) bool {
return sess.Context().Value(contextKeyAgentRequest) == true
}

// NewAgentListener sets up a temporary Unix socket that can be communicated
// to the session environment and used for forwarding connections.
func NewAgentListener() (net.Listener, error) {
dir, err := ioutil.TempDir("", agentTempDir)
if err != nil {
return nil, err
}
l, err := net.Listen("unix", path.Join(dir, agentListenFile))
if err != nil {
return nil, err
}
return l, nil
}

// ForwardAgentConnections takes connections from a listener to proxy into the
// session on the OpenSSH channel for agent connections. It blocks and services
// connections until the listener stop accepting.
func ForwardAgentConnections(l net.Listener, s Session) {
sshConn := s.Context().Value(ContextKeyConn).(gossh.Conn)
for {
conn, err := l.Accept()
if err != nil {
return
}
go func(conn net.Conn) {
defer conn.Close()
channel, reqs, err := sshConn.OpenChannel(agentChannelType, nil)
if err != nil {
return
}
defer channel.Close()
go gossh.DiscardRequests(reqs)
var wg sync.WaitGroup
wg.Add(2)
go func() {
io.Copy(conn, channel)
conn.(*net.UnixConn).CloseWrite()
wg.Done()
}()
go func() {
io.Copy(channel, conn)
channel.CloseWrite()
wg.Done()
}()
wg.Wait()
}(conn)
}
}
7 changes: 6 additions & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ssh
import (
"context"
"net"
"encoding/hex"

gossh "golang.org/x/crypto/ssh"
)
Expand Down Expand Up @@ -46,6 +47,10 @@ var (
// The associated value will be of type *Server.
ContextKeyServer = &contextKey{"ssh-server"}

// ContextKeyConn is a context key for use with Contexts in this package.
// The associated value will be of type gossh.Conn.
ContextKeyConn = &contextKey{"ssh-conn"}

// ContextKeyPublicKey is a context key for use with Contexts in this package.
// The associated value will be of type PublicKey.
ContextKeyPublicKey = &contextKey{"public-key"}
Expand Down Expand Up @@ -101,7 +106,7 @@ func (ctx *sshContext) applyConnMetadata(conn gossh.ConnMetadata) {
if ctx.Value(ContextKeySessionID) != nil {
return
}
ctx.SetValue(ContextKeySessionID, string(conn.SessionID()))
ctx.SetValue(ContextKeySessionID, hex.EncodeToString(conn.SessionID()))
ctx.SetValue(ContextKeyClientVersion, string(conn.ClientVersion()))
ctx.SetValue(ContextKeyServerVersion, string(conn.ServerVersion()))
ctx.SetValue(ContextKeyUser, conn.User())
Expand Down
2 changes: 2 additions & 0 deletions options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func TestPasswordAuth(t *testing.T) {
Auth: []gossh.AuthMethod{
gossh.Password(testPass),
},
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
}, PasswordAuth(func(ctx Context, password string) bool {
if ctx.User() != testUser {
t.Fatalf("user = %#v; want %#v", ctx.User(), testUser)
Expand Down Expand Up @@ -57,6 +58,7 @@ func TestPasswordAuthBadPass(t *testing.T) {
Auth: []gossh.AuthMethod{
gossh.Password("testpass"),
},
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
})
if err != nil {
if !strings.Contains(err.Error(), "unable to authenticate") {
Expand Down
35 changes: 12 additions & 23 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@ type Server struct {
PasswordHandler PasswordHandler // password authentication handler
PublicKeyHandler PublicKeyHandler // public key authentication handler
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil

channelHandlers map[string]channelHandler
}

// internal for now
type channelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx *sshContext)

func (srv *Server) ensureHostSigner() error {
if len(srv.HostSigners) == 0 {
signer, err := generateSigner()
Expand All @@ -34,6 +39,9 @@ func (srv *Server) ensureHostSigner() error {
}

func (srv *Server) config(ctx *sshContext) *gossh.ServerConfig {
srv.channelHandlers = map[string]channelHandler{
"session": sessionHandler,
}
config := &gossh.ServerConfig{}
for _, signer := range srv.HostSigners {
config.AddHostKey(signer)
Expand Down Expand Up @@ -114,36 +122,17 @@ func (srv *Server) handleConn(conn net.Conn) {
// TODO: trigger event callback
return
}
ctx.SetValue(ContextKeyConn, sshConn)
ctx.applyConnMetadata(sshConn)
go gossh.DiscardRequests(reqs)
for ch := range chans {
if ch.ChannelType() != "session" {
handler, found := srv.channelHandlers[ch.ChannelType()]
if !found {
ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
continue
}
go srv.handleChannel(sshConn, ch, ctx)
}
}

func (srv *Server) handleChannel(conn *gossh.ServerConn, newChan gossh.NewChannel, ctx *sshContext) {
ch, reqs, err := newChan.Accept()
if err != nil {
// TODO: trigger event callback
return
}
sess := srv.newSession(conn, ch, ctx)
sess.handleRequests(reqs)
}

func (srv *Server) newSession(conn *gossh.ServerConn, ch gossh.Channel, ctx *sshContext) *session {
sess := &session{
Channel: ch,
conn: conn,
handler: srv.Handler,
ptyCb: srv.PtyCallback,
ctx: ctx,
go handler(srv, sshConn, ch, ctx)
}
return sess
}

// ListenAndServe listens on the TCP network address srv.Addr and then calls
Expand Down
22 changes: 22 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,22 @@ type Session interface {
// TODO: Signals(c chan<- Signal)
}

func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx *sshContext) {
ch, reqs, err := newChan.Accept()
if err != nil {
// TODO: trigger event callback
return
}
sess := &session{
Channel: ch,
conn: conn,
handler: srv.Handler,
ptyCb: srv.PtyCallback,
ctx: ctx,
}
sess.handleRequests(reqs)
}

type session struct {
gossh.Channel
conn *gossh.ServerConn
Expand Down Expand Up @@ -205,6 +221,12 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
sess.winch <- win
}
req.Reply(ok, nil)
case agentRequestType:
// TODO: option/callback to allow agent forwarding
setAgentRequested(sess)
req.Reply(true, nil)
default:
// TODO: debug log
}
}
}
3 changes: 3 additions & 0 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*g
},
}
}
if config.HostKeyCallback == nil {
config.HostKeyCallback = gossh.InsecureIgnoreHostKey()
}
client, err := gossh.Dial("tcp", addr, config)
if err != nil {
t.Fatal(err)
Expand Down

0 comments on commit 1051a0d

Please sign in to comment.