Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dmgardiner25 committed Jan 4, 2023
1 parent 731ba68 commit 000a84d
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 56 deletions.
14 changes: 0 additions & 14 deletions internal/codespaces/codespaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/cenkalti/backoff/v4"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/cli/cli/v2/pkg/liveshare"
)

Expand Down Expand Up @@ -80,16 +79,3 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session
Logger: sessionLogger,
})
}

// Helper function to connect to the internal RPC server and return an RPC invoker for it
func CreateRPCInvoker(ctx context.Context, session *liveshare.Session, token string) (*rpc.Invoker, error) {
ctx, cancel := context.WithTimeout(ctx, rpc.ConnectionTimeout)
defer cancel()

invoker, err := rpc.Connect(ctx, session, token)
if err != nil {
return nil, fmt.Errorf("error connecting to internal server: %w", err)
}

return invoker, nil
}
67 changes: 38 additions & 29 deletions internal/codespaces/rpc/invoker.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (

"github.com/cli/cli/v2/internal/codespaces/rpc/jupyter"
"github.com/cli/cli/v2/pkg/liveshare"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
Expand All @@ -28,35 +27,45 @@ const (
codespacesInternalSessionName = "CodespacesInternal"
)

type liveshareSession interface {
type Invoker interface {
Close() error
GetSharedServers(context.Context) ([]*liveshare.Port, error)
KeepAlive(string)
OpenStreamingChannel(context.Context, liveshare.ChannelID) (ssh.Channel, error)
StartSharing(context.Context, string, int) (liveshare.ChannelID, error)
StartSSHServer(context.Context) (int, string, error)
StartSSHServerWithOptions(context.Context, liveshare.StartSSHServerOptions) (int, string, error)
RebuildContainer(context.Context, bool) error
StartJupyterServer(ctx context.Context) (int, string, error)
RebuildContainer(ctx context.Context, full bool) error
StartSSHServer(ctx context.Context) (int, string, error)
StartSSHServerWithOptions(ctx context.Context, options liveshare.StartSSHServerOptions) (int, string, error)
}

type Invoker struct {
type invoker struct {
conn *grpc.ClientConn
token string
session liveshareSession
session liveshare.LiveshareSession
listener net.Listener
jupyterClient jupyter.JupyterServerHostClient
cancelPF context.CancelFunc
}

// Finds a free port to listen on and creates a new gRPC client that connects to that port
func Connect(ctx context.Context, session liveshareSession, token string) (*Invoker, error) {
// Connects to the internal RPC server and returns a new invoker for it
func CreateInvoker(ctx context.Context, session liveshare.LiveshareSession, token string) (Invoker, error) {
ctx, cancel := context.WithTimeout(ctx, ConnectionTimeout)
defer cancel()

invoker, err := connect(ctx, session, token)
if err != nil {
return nil, fmt.Errorf("error connecting to internal server: %w", err)
}

return invoker, nil
}

// Finds a free port to listen on and creates a new RPC invoker that connects to that port
func connect(ctx context.Context, session liveshare.LiveshareSession, token string) (Invoker, error) {
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", 0))
if err != nil {
return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err)
}
localAddress := fmt.Sprintf("127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port)

invoker := &Invoker{
invoker := &invoker{
token: token,
session: session,
listener: listener,
Expand Down Expand Up @@ -113,30 +122,30 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Invo
}

// Closes the gRPC connection
func (g *Invoker) Close() error {
g.cancelPF()
func (i *invoker) Close() error {
i.cancelPF()

// Closing the local listener effectively closes the gRPC connection
if err := g.listener.Close(); err != nil {
g.conn.Close() // If we fail to close the listener, explicitly close the gRPC connection and ignore any error
if err := i.listener.Close(); err != nil {
i.conn.Close() // If we fail to close the listener, explicitly close the gRPC connection and ignore any error
return fmt.Errorf("failed to close local tcp port listener: %w", err)
}

return nil
}

// Appends the authentication token to the gRPC context
func (g *Invoker) appendMetadata(ctx context.Context) context.Context {
return metadata.AppendToOutgoingContext(ctx, "Authorization", "Bearer "+g.token)
func (i *invoker) appendMetadata(ctx context.Context) context.Context {
return metadata.AppendToOutgoingContext(ctx, "Authorization", "Bearer "+i.token)
}

// Starts a remote JupyterLab server to allow the user to connect to the codespace via JupyterLab in their browser
func (g *Invoker) StartJupyterServer(ctx context.Context) (port int, serverUrl string, err error) {
ctx = g.appendMetadata(ctx)
func (i *invoker) StartJupyterServer(ctx context.Context) (port int, serverUrl string, err error) {
ctx = i.appendMetadata(ctx)
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
defer cancel()

response, err := g.jupyterClient.GetRunningServer(ctx, &jupyter.GetRunningServerRequest{})
response, err := i.jupyterClient.GetRunningServer(ctx, &jupyter.GetRunningServerRequest{})
if err != nil {
return 0, "", fmt.Errorf("failed to invoke JupyterLab RPC: %w", err)
}
Expand All @@ -154,16 +163,16 @@ func (g *Invoker) StartJupyterServer(ctx context.Context) (port int, serverUrl s
}

// Rebuilds the container using cached layers by default or from scratch if full is true
func (g *Invoker) RebuildContainer(ctx context.Context, full bool) error {
return g.session.RebuildContainer(ctx, full)
func (i *invoker) RebuildContainer(ctx context.Context, full bool) error {
return i.session.RebuildContainer(ctx, full)
}

// Starts a remote SSH server to allow the user to connect to the codespace via SSH
func (g *Invoker) StartSSHServer(ctx context.Context) (int, string, error) {
return g.session.StartSSHServer(ctx)
func (i *invoker) StartSSHServer(ctx context.Context) (int, string, error) {
return i.session.StartSSHServer(ctx)
}

// Starts a remote SSH server to allow the user to connect to the codespace via SSH
func (g *Invoker) StartSSHServerWithOptions(ctx context.Context, options liveshare.StartSSHServerOptions) (int, string, error) {
return g.session.StartSSHServerWithOptions(ctx, options)
func (i *invoker) StartSSHServerWithOptions(ctx context.Context, options liveshare.StartSSHServerOptions) (int, string, error) {
return i.session.StartSSHServerWithOptions(ctx, options)
}
9 changes: 4 additions & 5 deletions internal/codespaces/rpc/invoker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ func startServer(t *testing.T) {
})
}

func connect(t *testing.T) (invoker *Invoker) {
func createTestInvoker(t *testing.T) Invoker {
t.Helper()

invoker, err := Connect(context.Background(), &rpctest.Session{}, "token")
invoker, err := CreateInvoker(context.Background(), &rpctest.Session{}, "token") //connect(context.Background(), &rpctest.Session{}, "token")
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
}
Expand All @@ -50,8 +50,7 @@ func connect(t *testing.T) (invoker *Invoker) {
// Test that the RPC invoker returns the correct port and URL when the JupyterLab server starts successfully
func TestStartJupyterServerSuccess(t *testing.T) {
startServer(t)
invoker := connect(t)

invoker := createTestInvoker(t)
port, url, err := invoker.StartJupyterServer(context.Background())
if err != nil {
t.Fatalf("expected %v, got %v", nil, err)
Expand All @@ -67,7 +66,7 @@ func TestStartJupyterServerSuccess(t *testing.T) {
// Test that the RPC invoker returns an error when the JupyterLab server fails to start
func TestStartJupyterServerFailure(t *testing.T) {
startServer(t)
invoker := connect(t)
invoker := createTestInvoker(t)
rpctest.JupyterMessage = "error message"
rpctest.JupyterResult = false
errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", rpctest.JupyterMessage)
Expand Down
3 changes: 2 additions & 1 deletion internal/codespaces/states.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/cli/cli/v2/internal/text"
"github.com/cli/cli/v2/pkg/liveshare"
)
Expand Down Expand Up @@ -59,7 +60,7 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl
localPort := listen.Addr().(*net.TCPAddr).Port

progress.StartProgressIndicatorWithLabel("Fetching SSH Details")
invoker, err := CreateRPCInvoker(ctx, session, "")
invoker, err := rpc.CreateInvoker(ctx, session, "")
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/cmd/codespace/jupyter.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"net"
"strings"

"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/cli/cli/v2/pkg/liveshare"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -45,7 +45,7 @@ func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) {
defer safeClose(session, &err)

a.StartProgressIndicatorWithLabel("Starting JupyterLab on codespace")
invoker, err := codespaces.CreateRPCInvoker(ctx, session, "")
invoker, err := rpc.CreateInvoker(ctx, session, "")
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/cmd/codespace/logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net"

"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/cli/cli/v2/pkg/liveshare"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -56,7 +57,7 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err
localPort := listen.Addr().(*net.TCPAddr).Port

a.StartProgressIndicatorWithLabel("Fetching SSH Details")
invoker, err := codespaces.CreateRPCInvoker(ctx, session, "")
invoker, err := rpc.CreateInvoker(ctx, session, "")
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/cmd/codespace/rebuild.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"context"
"fmt"

"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -52,7 +52,7 @@ func (a *App) Rebuild(ctx context.Context, codespaceName string, full bool) (err
}
defer safeClose(session, &err)

invoker, err := codespaces.CreateRPCInvoker(ctx, session, "")
invoker, err := rpc.CreateInvoker(ctx, session, "")
if err != nil {
return err
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/cmd/codespace/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/cli/cli/v2/internal/config"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/liveshare"
Expand Down Expand Up @@ -173,7 +174,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
defer safeClose(session, &err)

a.StartProgressIndicatorWithLabel("Fetching SSH Details")
invoker, err := codespaces.CreateRPCInvoker(ctx, session, "")
invoker, err := rpc.CreateInvoker(ctx, session, "")
if err != nil {
return err
}
Expand Down Expand Up @@ -514,7 +515,7 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro
} else {
defer safeClose(session, &err)

invoker, err := codespaces.CreateRPCInvoker(ctx, session, "")
invoker, err := rpc.CreateInvoker(ctx, session, "")
if err != nil {
result.err = fmt.Errorf("error connecting to codespace: %w", err)
} else {
Expand Down
12 changes: 12 additions & 0 deletions pkg/liveshare/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@ type ChannelID struct {
name, condition string
}

// Interface to allow the mocking of the liveshare session
type LiveshareSession interface {
Close() error
GetSharedServers(context.Context) ([]*Port, error)
KeepAlive(string)
OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error)
StartSharing(context.Context, string, int) (ChannelID, error)
StartSSHServer(context.Context) (int, string, error)
StartSSHServerWithOptions(context.Context, StartSSHServerOptions) (int, string, error)
RebuildContainer(context.Context, bool) error
}

// A Session represents the session between a connected Live Share client and server.
type Session struct {
ssh *sshSession
Expand Down

0 comments on commit 000a84d

Please sign in to comment.