Skip to content

Commit

Permalink
Close port forward writer on reader
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbrose committed Jan 20, 2023
1 parent ba27e5b commit 2b95cbc
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 41 deletions.
16 changes: 16 additions & 0 deletions internal/codespaces/codespaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net"
"time"

"github.com/cenkalti/backoff/v4"
Expand Down Expand Up @@ -79,3 +80,18 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session
Logger: sessionLogger,
})
}

// ListenTCP starts a localhost tcp listener and returns the listener and bound port
func ListenTCP(port int) (*net.TCPListener, int, error) {
addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return nil, 0, fmt.Errorf("failed to build tcp address: %w", err)
}
listener, err := net.ListenTCP("tcp", addr)
if err != nil {
return nil, 0, fmt.Errorf("failed to listen to local port over tcp: %w", err)
}
port = listener.Addr().(*net.TCPAddr).Port

return listener, port, nil
}
19 changes: 16 additions & 3 deletions internal/codespaces/rpc/invoker.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ func CreateInvoker(ctx context.Context, session liveshare.LiveshareSession) (Inv

// Finds a free port to listen on and creates a new RPC invoker that connects to that port
func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker, error) {
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", 0))
listener, err := listenTCP()
if err != nil {
return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err)
return nil, err
}
localAddress := fmt.Sprintf("127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port)
localAddress := listener.Addr().String()

invoker := &invoker{
session: session,
Expand Down Expand Up @@ -229,3 +229,16 @@ func (i *invoker) StartSSHServerWithOptions(ctx context.Context, options StartSS

return port, response.User, nil
}

func listenTCP() (*net.TCPListener, error) {
addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
if err != nil {
return nil, fmt.Errorf("failed to build tcp address: %w", err)
}
listener, err := net.ListenTCP("tcp", addr)
if err != nil {
return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err)
}

return listener, nil
}
4 changes: 1 addition & 3 deletions internal/codespaces/states.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"fmt"
"io"
"log"
"net"
"time"

"github.com/cli/cli/v2/internal/codespaces/api"
Expand Down Expand Up @@ -53,11 +52,10 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl
}()

// Ensure local port is listening before client (getPostCreateOutput) connects.
listen, err := net.Listen("tcp", "127.0.0.1:0") // arbitrary port
listen, localPort, err := ListenTCP(0)
if err != nil {
return err
}
localPort := listen.Addr().(*net.TCPAddr).Port

progress.StartProgressIndicatorWithLabel("Fetching SSH Details")
invoker, err := rpc.CreateInvoker(ctx, session)
Expand Down
3 changes: 2 additions & 1 deletion pkg/cmd/codespace/jupyter.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net"
"strings"

"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/cli/cli/v2/pkg/liveshare"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -58,7 +59,7 @@ func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) {
a.StopProgressIndicator()

// Pass 0 to pick a random port
listen, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", 0))
listen, _, err := codespaces.ListenTCP(0)
if err != nil {
return err
}
Expand Down
4 changes: 1 addition & 3 deletions pkg/cmd/codespace/logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package codespace
import (
"context"
"fmt"
"net"

"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/codespaces/rpc"
Expand Down Expand Up @@ -49,12 +48,11 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err
defer safeClose(session, &err)

// Ensure local port is listening before client (getPostCreateOutput) connects.
listen, err := net.Listen("tcp", "127.0.0.1:0") // arbitrary port
listen, localPort, err := codespaces.ListenTCP(0)
if err != nil {
return err
}
defer listen.Close()
localPort := listen.Addr().(*net.TCPAddr).Port

a.StartProgressIndicatorWithLabel("Fetching SSH Details")
invoker, err := rpc.CreateInvoker(ctx, session)
Expand Down
3 changes: 1 addition & 2 deletions pkg/cmd/codespace/ports.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"strconv"
"strings"
Expand Down Expand Up @@ -390,7 +389,7 @@ func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []st
for _, pair := range portPairs {
pair := pair
group.Go(func() error {
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", pair.local))
listen, _, err := codespaces.ListenTCP(pair.local)
if err != nil {
return err
}
Expand Down
25 changes: 2 additions & 23 deletions pkg/cmd/codespace/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ import (
"context"
"errors"
"fmt"
"io"
"log"
"net"
"os"
"os/exec"
"path"
Expand Down Expand Up @@ -188,7 +186,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e

if opts.stdio {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, true)
stdio := newReadWriteCloser(os.Stdin, os.Stdout)
stdio := liveshare.NewReadWriteHalfCloser(os.Stdin, os.Stdout)
err := fwd.Forward(ctx, stdio) // always non-nil
return fmt.Errorf("tunnel closed: %w", err)
}
Expand All @@ -199,12 +197,11 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
// Ensure local port is listening before client (Shell) connects.
// Unless the user specifies a server port, localSSHServerPort is 0
// and thus the client will pick a random port.
listen, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", localSSHServerPort))
listen, localSSHServerPort, err := codespaces.ListenTCP(localSSHServerPort)
if err != nil {
return err
}
defer listen.Close()
localSSHServerPort = listen.Addr().(*net.TCPAddr).Port

connectDestination := opts.profile
if connectDestination == "" {
Expand Down Expand Up @@ -745,21 +742,3 @@ func (fl *fileLogger) Name() string {
func (fl *fileLogger) Close() error {
return fl.f.Close()
}

type combinedReadWriteCloser struct {
io.ReadCloser
io.WriteCloser
}

func newReadWriteCloser(reader io.ReadCloser, writer io.WriteCloser) io.ReadWriteCloser {
return &combinedReadWriteCloser{reader, writer}
}

func (crwc *combinedReadWriteCloser) Close() error {
werr := crwc.WriteCloser.Close()
rerr := crwc.ReadCloser.Close()
if werr != nil {
return werr
}
return rerr
}
39 changes: 34 additions & 5 deletions pkg/liveshare/port_forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,33 @@ type portForwardingSession interface {
KeepAlive(string)
}

type ReadWriteHalfCloser interface {
io.ReadWriteCloser
CloseWrite() error
}

type combinedReadWriteHalfCloser struct {
io.ReadCloser
io.WriteCloser
}

func NewReadWriteHalfCloser(reader io.ReadCloser, writer io.WriteCloser) ReadWriteHalfCloser {
return &combinedReadWriteHalfCloser{reader, writer}
}

func (crwc *combinedReadWriteHalfCloser) Close() error {
werr := crwc.WriteCloser.Close()
rerr := crwc.ReadCloser.Close()
if werr != nil {
return werr
}
return rerr
}

func (crwc *combinedReadWriteHalfCloser) CloseWrite() error {
return crwc.WriteCloser.Close()
}

// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote
// container to a local destination such as a network port or Go reader/writer.
type PortForwarder struct {
Expand Down Expand Up @@ -48,7 +75,7 @@ func NewPortForwarder(session portForwardingSession, name string, remotePort int
// until it encounters the first error, which may include context
// cancellation. Its error result is always non-nil. The caller is
// responsible for closing the listening port.
func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) {
func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen *net.TCPListener) (err error) {
id, err := fwd.shareRemotePort(ctx)
if err != nil {
return err
Expand All @@ -65,7 +92,7 @@ func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.List
}
go func() {
for {
conn, err := listen.Accept()
conn, err := listen.AcceptTCP()
if err != nil {
sendError(err)
return
Expand All @@ -84,7 +111,7 @@ func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.List

// Forward forwards traffic between the container's remote port and
// the specified read/write stream. On return, the stream is closed.
func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) error {
func (fwd *PortForwarder) Forward(ctx context.Context, conn ReadWriteHalfCloser) error {
id, err := fwd.shareRemotePort(ctx)
if err != nil {
conn.Close()
Expand Down Expand Up @@ -143,7 +170,7 @@ func (t *trafficMonitor) Read(p []byte) (n int, err error) {
}

// handleConnection handles forwarding for a single accepted connection, then closes it.
func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn io.ReadWriteCloser) (err error) {
func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn ReadWriteHalfCloser) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection")
defer span.Finish()

Expand All @@ -165,9 +192,11 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, co

// bi-directional copy of data.
errs := make(chan error, 2)
copyConn := func(w io.Writer, r io.Reader) {
copyConn := func(w ReadWriteHalfCloser, r io.Reader) {
_, err := io.Copy(w, r)
errs <- err

w.CloseWrite()
}

var (
Expand Down
6 changes: 5 additions & 1 deletion pkg/liveshare/port_forwarder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ func TestPortForwarderStart(t *testing.T) {
t.Fatal(err)
}
defer listen.Close()
tcpListener, ok := listen.(*net.TCPListener)
if !ok {
t.Fatal("net.Listen did not return a TCPListener")
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand All @@ -82,7 +86,7 @@ func TestPortForwarderStart(t *testing.T) {

done := make(chan error, 2)
go func() {
done <- NewPortForwarder(session, "ssh", port, false).ForwardToListener(ctx, listen)
done <- NewPortForwarder(session, "ssh", port, false).ForwardToListener(ctx, tcpListener)
}()

go func() {
Expand Down

0 comments on commit 2b95cbc

Please sign in to comment.