Skip to content

Commit

Permalink
Support SSH port forwarding ( local forward ) for SSH Runner
Browse files Browse the repository at this point in the history
  • Loading branch information
k1LoW committed Feb 10, 2023
1 parent 70bb195 commit 1e6a94a
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 19 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1321,6 +1321,7 @@ runners:
# host: myserver
# sshConfig: path/to/ssh_config
# keepSession: false
# localForward: '33306:127.0.0.1:3306'
```

See [testdata/book/sshd.yml](testdata/book/sshd.yml).
Expand Down
18 changes: 15 additions & 3 deletions book.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,15 +341,27 @@ func (bk *book) parseSSHRunnerWithDetailed(name string, b []byte) (bool, error)
}
opts = append(opts, sshc.IdentityFile(p))
}
var lf *sshLocalForward
if c.LocalForward != "" {
if strings.Count(c.LocalForward, ":") != 2 {
return false, fmt.Errorf("invalid SSH runner: '%s': invalid localForward option: %s", name, c.LocalForward)
}
splitted := strings.SplitN(c.LocalForward, ":", 2)
lf = &sshLocalForward{
local: fmt.Sprintf("127.0.0.1:%s", splitted[0]),
remote: splitted[1],
}
}

client, err := sshc.NewClient(host, opts...)
if err != nil {
return false, err
}
r := &sshRunner{
name: name,
client: client,
keepSession: c.KeepSession,
name: name,
client: client,
keepSession: c.KeepSession,
localForward: lf,
}

if r.keepSession {
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ require (
github.com/golang/protobuf v1.5.2
github.com/google/go-cmp v0.5.9
github.com/jhump/protoreflect v1.14.1
github.com/juliangruber/go-intersect v1.1.0
github.com/k1LoW/curlreq v0.3.2
github.com/k1LoW/duration v1.2.0
github.com/k1LoW/exec v0.2.0
Expand All @@ -31,7 +32,7 @@ require (
github.com/k1LoW/grpcurlreq v0.1.0
github.com/k1LoW/httpstub v0.5.0
github.com/k1LoW/repin v0.3.4
github.com/k1LoW/sshc/v3 v3.0.1
github.com/k1LoW/sshc/v3 v3.1.0
github.com/k1LoW/stopw v0.7.1
github.com/k1LoW/urlfilepath v0.1.0
github.com/lestrrat-go/backoff/v2 v2.0.8
Expand Down Expand Up @@ -86,7 +87,6 @@ require (
github.com/josharian/intern v1.0.0 // indirect
github.com/josharian/mapfs v0.0.0-20210615234106-095c008854e6 // indirect
github.com/josharian/txtarfs v0.0.0-20210615234325-77aca6df5bca // indirect
github.com/juliangruber/go-intersect v1.1.0 // indirect
github.com/k1LoW/go-github-client/v48 v48.2.3 // indirect
github.com/kevinburke/ssh_config v1.2.0 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
Expand Down
5 changes: 3 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de h1:FxWPpzIjnTlhP
github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de/go.mod h1:DCaWoUhZrYW9p1lxo/cm8EmUOOzAPSEZNGF2DK1dJgw=
github.com/bmatcuk/doublestar/v4 v4.6.0 h1:HTuxyug8GyFbRkrffIpzNCSK4luc0TY3wzXvzIZhEXc=
github.com/bmatcuk/doublestar/v4 v4.6.0/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/buildkite/interpolate v0.0.0-20200526001904-07f35b4ae251 h1:k6UDF1uPYOs0iy1HPeotNa155qXRWrzKnqAaGXHLZCE=
github.com/buildkite/interpolate v0.0.0-20200526001904-07f35b4ae251/go.mod h1:gbPR1gPu9dB96mucYIR7T3B7p/78hRVSOuzIWLHK2Y4=
Expand Down Expand Up @@ -258,8 +259,8 @@ github.com/k1LoW/httpstub v0.5.0 h1:dLzScTNj0uP0LeBWASCNUteKXtl5kgwebkntD0avCu4=
github.com/k1LoW/httpstub v0.5.0/go.mod h1:chz4+4x3yYdVmYs24D7S4XHDP4zeaJvTWfFzmCtFuMs=
github.com/k1LoW/repin v0.3.4 h1:xcNuBBc/ISHUNBzjXNTCux4OYZND5ZMiyz4SrRtpDhg=
github.com/k1LoW/repin v0.3.4/go.mod h1:1abQMGdYFegTCsxbhZ3O5P8aKENYS37UhQKwKYKiYkg=
github.com/k1LoW/sshc/v3 v3.0.1 h1:QkyyrTp96Lf2kgXpo8+BWC25BiMTBGRD1z230bBYk3E=
github.com/k1LoW/sshc/v3 v3.0.1/go.mod h1:mZgv89TRnWw90AFeDgso/rL9CS1pE9DG3P5JhPTMdXk=
github.com/k1LoW/sshc/v3 v3.1.0 h1:vzy55KbN9uT+IXLH73Tb4vyzy+iLcxe61bSQpL7zhzs=
github.com/k1LoW/sshc/v3 v3.1.0/go.mod h1:mZgv89TRnWw90AFeDgso/rL9CS1pE9DG3P5JhPTMdXk=
github.com/k1LoW/stopw v0.7.1 h1:yQtO45xgxys8AMxz7kOoxOoUNyIcId/DkMf0bOvDNqE=
github.com/k1LoW/stopw v0.7.1/go.mod h1:YzYHAs+C2G8O46iozN8yfC4K/SeuB73C3VQDBqN5+9Y=
github.com/k1LoW/urlfilepath v0.1.0 h1:JU2FJISuw9oGHy0SAC85O85pnYS3/Z2r0TLlIpy215E=
Expand Down
18 changes: 15 additions & 3 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,16 +469,28 @@ func SSHRunnerWithOptions(name string, opts ...sshRunnerOption) Option {
}
opts = append(opts, sshc.IdentityFile(p))
}
var lf *sshLocalForward
if c.LocalForward != "" {
if strings.Count(c.LocalForward, ":") != 2 {
return fmt.Errorf("invalid SSH runner: '%s': invalid localForward option: %s", name, c.LocalForward)
}
splitted := strings.SplitN(c.LocalForward, ":", 2)
lf = &sshLocalForward{
local: fmt.Sprintf("127.0.0.1:%s", splitted[0]),
remote: splitted[1],
}
}

client, err := sshc.NewClient(host, opts...)
if err != nil {
return err
}

r := &sshRunner{
name: name,
client: client,
keepSession: c.KeepSession,
name: name,
client: client,
keepSession: c.KeepSession,
localForward: lf,
}

if r.keepSession {
Expand Down
9 changes: 9 additions & 0 deletions runner_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type sshRunnerConfig struct {
Port int `yaml:"port,omitempty"`
IdentityFile string `yaml:"identityFile,omitempty"`
KeepSession bool `yaml:"keepSession,omitempty"`
LocalForward string `yaml:"localForward,omitempty"`
}

type httpRunnerOption func(*httpRunnerConfig) error
Expand Down Expand Up @@ -223,3 +224,11 @@ func KeepSession(enable bool) sshRunnerOption {
return nil
}
}

func LocalForward(l string) sshRunnerOption {
return func(c *sshRunnerConfig) error {
c.LocalForward = l
c.KeepSession = true
return nil
}
}
93 changes: 84 additions & 9 deletions ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,38 @@ import (
"errors"
"fmt"
"io"
"log"
"net"
"net/url"
"strconv"
"strings"
"sync"
"time"

"github.com/k1LoW/sshc/v3"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
)

const sshOutTimeout = 1 * time.Second

type sshRunner struct {
name string
addr string
client *ssh.Client
sess *ssh.Session
stdin io.WriteCloser
stdout chan string
stderr chan string
keepSession bool
operator *operator
name string
addr string
client *ssh.Client
sess *ssh.Session
stdin io.WriteCloser
stdout chan string
stderr chan string
keepSession bool
localForward *sshLocalForward
sessCancel context.CancelFunc
operator *operator
}

type sshLocalForward struct {
local string
remote string
}

type sshCommand struct {
Expand Down Expand Up @@ -75,6 +86,7 @@ func (rnr *sshRunner) startSession() error {
if !rnr.keepSession {
return errors.New("could not use startSession() when keepSession = false")
}
ctx, cancel := context.WithCancel(context.Background())

sess, err := rnr.client.NewSession()
if err != nil {
Expand Down Expand Up @@ -120,10 +132,36 @@ func (rnr *sshRunner) startSession() error {
close(el)
}()

// local forward
if rnr.localForward != nil {
// remote
remote, err := rnr.client.Dial("tcp", rnr.localForward.remote)
if err != nil {
return err
}
local, err := net.Listen("tcp", rnr.localForward.local)
if err != nil {
return err
}

go func() {
for {
lc, err := local.Accept()
if err != nil {
log.Println(err)
}
if err := handleConns(ctx, lc, remote); err != nil {
log.Println(err)
}
}
}()
}

rnr.sess = sess
rnr.stdin = stdin
rnr.stdout = ol
rnr.stderr = el
rnr.sessCancel = cancel
return nil
}

Expand All @@ -132,10 +170,14 @@ func (rnr *sshRunner) closeSession() error {
return nil
}
rnr.sess.Close()
if rnr.sessCancel != nil {
rnr.sessCancel()
}
rnr.sess = nil
rnr.stdin = nil
rnr.stdout = nil
rnr.stderr = nil
rnr.sessCancel = nil
return nil
}

Expand Down Expand Up @@ -215,3 +257,36 @@ func (rnr *sshRunner) runOnce(ctx context.Context, c *sshCommand) error {

return nil
}

func handleConns(ctx context.Context, lc net.Conn, remote net.Conn) error {
defer lc.Close()
var wg sync.WaitGroup
eg, _ := errgroup.WithContext(ctx)
wg.Add(1)

// remote -> local
eg.Go(func() error {
_, err := io.Copy(lc, remote)
if err != nil {
return err
}
wg.Done()
return nil
})

// local -> remote
eg.Go(func() error {
_, err := io.Copy(remote, lc)
if err != nil {
return err
}
wg.Done()
return nil
})

wg.Wait()
if err := eg.Wait(); err != nil {
return err
}
return nil
}

0 comments on commit 1e6a94a

Please sign in to comment.