Skip to content

Commit

Permalink
Rework virt-api dialers to be more flexible
Browse files Browse the repository at this point in the history
Allow virt-api websocket dialers to return the virt-handler websocket
connection itself or the underlying connection based on the need.

Signed-off-by: Roman Mohr <[email protected]>
rmohr committed Sep 20, 2022
1 parent 5f4941b commit 74165c2
Showing 5 changed files with 160 additions and 79 deletions.
100 changes: 63 additions & 37 deletions pkg/virt-api/rest/dialers.go
Original file line number Diff line number Diff line change
@@ -4,6 +4,8 @@ import (
"fmt"
"net"

"github.com/gorilla/websocket"

restful "github.com/emicklei/go-restful"
"k8s.io/apimachinery/pkg/api/errors"

@@ -14,47 +16,71 @@ import (
"kubevirt.io/kubevirt/pkg/virt-api/definitions"
)

func netDialer(request *restful.Request) dialer {
return func(vmi *v1.VirtualMachineInstance) (net.Conn, *errors.StatusError) {
logger := log.Log.Object(vmi)

targetIP, err := getTargetInterfaceIP(vmi)
if err != nil {
logger.Reason(err).Error("Can't establish TCP tunnel.")
return nil, errors.NewBadRequest(err.Error())
}

port := request.PathParameter(definitions.PortParamName)
if len(port) < 1 {
return nil, errors.NewBadRequest("port must not be empty")
}

protocol := "tcp"
if protocolParam := request.PathParameter(definitions.ProtocolParamName); len(protocolParam) > 0 {
protocol = protocolParam
}

addr := fmt.Sprintf("%s:%s", targetIP, port)
conn, err := net.Dial(protocol, addr)
if err != nil {
logger.Reason(err).Errorf("Can't dial %s %s", protocol, addr)
return nil, errors.NewInternalError(fmt.Errorf("dialing VM: %w", err))
}
return conn, nil
type netDial struct {
request *restful.Request
}

type handlerDial struct {
getURL URLResolver
app *SubresourceAPIApp
}

func (h handlerDial) Dial(vmi *v1.VirtualMachineInstance) (*websocket.Conn, *errors.StatusError) {
url, _, statusError := h.app.getVirtHandlerFor(vmi, h.getURL)
if statusError != nil {
return nil, statusError
}
conn, _, err := kubecli.Dial(url, h.app.handlerTLSConfiguration)
if err != nil {
return nil, errors.NewInternalError(fmt.Errorf("dialing virt-handler: %w", err))
}
return conn, nil
}

func (h handlerDial) DialUnderlying(vmi *v1.VirtualMachineInstance) (net.Conn, *errors.StatusError) {
conn, err := h.Dial(vmi)
if err != nil {
return nil, err
}
return conn.UnderlyingConn(), nil
}

func (n netDial) Dial(vmi *v1.VirtualMachineInstance) (*websocket.Conn, *errors.StatusError) {
panic("don't call me")
}

func (n netDial) DialUnderlying(vmi *v1.VirtualMachineInstance) (net.Conn, *errors.StatusError) {
logger := log.Log.Object(vmi)

targetIP, err := getTargetInterfaceIP(vmi)
if err != nil {
logger.Reason(err).Error("Can't establish TCP tunnel.")
return nil, errors.NewBadRequest(err.Error())
}

port := n.request.PathParameter(definitions.PortParamName)
if len(port) < 1 {
return nil, errors.NewBadRequest("port must not be empty")
}

protocol := "tcp"
if protocolParam := n.request.PathParameter(definitions.ProtocolParamName); len(protocolParam) > 0 {
protocol = protocolParam
}

addr := fmt.Sprintf("%s:%s", targetIP, port)
conn, err := net.Dial(protocol, addr)
if err != nil {
logger.Reason(err).Errorf("Can't dial %s %s", protocol, addr)
return nil, errors.NewInternalError(fmt.Errorf("dialing VM: %w", err))
}
return conn, nil
}

func (app *SubresourceAPIApp) virtHandlerDialer(getURL URLResolver) dialer {
return func(vmi *v1.VirtualMachineInstance) (net.Conn, *errors.StatusError) {
url, _, statusError := app.getVirtHandlerFor(vmi, getURL)
if statusError != nil {
return nil, statusError
}
conn, _, err := kubecli.Dial(url, app.handlerTLSConfiguration)
if err != nil {
return nil, errors.NewInternalError(fmt.Errorf("dialing virt-handler: %w", err))
}
return conn.UnderlyingConn(), nil
return handlerDial{
getURL: getURL,
app: app,
}
}

2 changes: 1 addition & 1 deletion pkg/virt-api/rest/portforward.go
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@ func (app *SubresourceAPIApp) PortForwardRequestHandler(fetcher vmiFetcher) rest
streamer := NewWebsocketStreamer(
fetcher,
validateVMIForPortForward,
netDialer(request),
netDial{request: request},
)

streamer.Handle(request, response)
79 changes: 52 additions & 27 deletions pkg/virt-api/rest/streamer.go
Original file line number Diff line number Diff line change
@@ -18,26 +18,32 @@ import (
)

type vmiFetcher func(namespace, name string) (*v1.VirtualMachineInstance, *errors.StatusError)
type dialer func(vmi *v1.VirtualMachineInstance) (net.Conn, *errors.StatusError)
type validator func(vmi *v1.VirtualMachineInstance) *errors.StatusError
type streamFunc func(clientConn *websocket.Conn, serverConn net.Conn, result chan<- streamFuncResult)
type streamFuncResult error

type dialer interface {
Dial(vmi *v1.VirtualMachineInstance) (*websocket.Conn, *errors.StatusError)
DialUnderlying(vmi *v1.VirtualMachineInstance) (net.Conn, *errors.StatusError)
}

type Streamer struct {
fetchVMI vmiFetcher
validateVMI validator
dial dialer
dialer *DirectDialer
keepAliveClient func(ctx context.Context, conn *websocket.Conn, cancel func())

streamToClient streamFunc
streamToServer streamFunc
}

type DirectDialer struct {
fetchVMI vmiFetcher
validateVMI validator
dial dialer
}

func NewRawStreamer(fetch vmiFetcher, validate validator, dial dialer) *Streamer {
return &Streamer{
fetchVMI: fetch,
validateVMI: validate,
dial: dial,
dialer: NewDirectDialer(fetch, validate, dial),
streamToServer: func(clientConn *websocket.Conn, serverConn net.Conn, result chan<- streamFuncResult) {
_, err := io.Copy(serverConn, clientConn.UnderlyingConn())
result <- err
@@ -51,9 +57,7 @@ func NewRawStreamer(fetch vmiFetcher, validate validator, dial dialer) *Streamer

func NewWebsocketStreamer(fetch vmiFetcher, validate validator, dial dialer) *Streamer {
return &Streamer{
fetchVMI: fetch,
validateVMI: validate,
dial: dial,
dialer: NewDirectDialer(fetch, validate, dial),
keepAliveClient: keepAliveClientStream,
streamToServer: func(clientConn *websocket.Conn, serverConn net.Conn, result chan<- streamFuncResult) {
_, err := kubecli.CopyFrom(serverConn, clientConn)
@@ -69,18 +73,13 @@ func NewWebsocketStreamer(fetch vmiFetcher, validate validator, dial dialer) *St
func (s *Streamer) Handle(request *restful.Request, response *restful.Response) error {
namespace := request.PathParameter(definitions.NamespaceParamName)
name := request.PathParameter(definitions.NameParamName)
serverConn, statusErr := s.dialer.DialUnderlying(namespace, name)

vmi, statusErr := s.fetchAndValidateVMI(namespace, name)
if statusErr != nil {
writeError(statusErr, response)
return statusErr
}

serverConn, statusErr := s.dial(vmi)
if statusErr != nil {
writeError(statusErr, response)
return statusErr
}
clientConn, err := clientConnectionUpgrade(request, response)
if err != nil {
writeError(errors.NewBadRequest(err.Error()), response)
@@ -112,17 +111,6 @@ func (s *Streamer) Handle(request *restful.Request, response *restful.Response)
return result2
}

func (s *Streamer) fetchAndValidateVMI(namespace, name string) (*v1.VirtualMachineInstance, *errors.StatusError) {
vmi, err := s.fetchVMI(namespace, name)
if err != nil {
return nil, err
}
if err := s.validateVMI(vmi); err != nil {
return nil, err
}
return vmi, nil
}

const streamTimeout = 10 * time.Second

func clientConnectionUpgrade(request *restful.Request, response *restful.Response) (*websocket.Conn, error) {
@@ -165,3 +153,40 @@ func keepAliveClientStream(ctx context.Context, conn *websocket.Conn, cancel fun
}
}
}

func NewDirectDialer(fetch vmiFetcher, validate validator, dial dialer) *DirectDialer {
return &DirectDialer{
fetchVMI: fetch,
validateVMI: validate,
dial: dial,
}
}

func (d *DirectDialer) Dial(namespace, name string) (*websocket.Conn, *errors.StatusError) {
vmi, err := d.fetchAndValidateVMI(namespace, name)
if err != nil {
return nil, err
}

return d.dial.Dial(vmi)
}

func (d *DirectDialer) DialUnderlying(namespace, name string) (net.Conn, *errors.StatusError) {
vmi, err := d.fetchAndValidateVMI(namespace, name)
if err != nil {
return nil, err
}

return d.dial.DialUnderlying(vmi)
}

func (d *DirectDialer) fetchAndValidateVMI(namespace, name string) (*v1.VirtualMachineInstance, *errors.StatusError) {
vmi, err := d.fetchVMI(namespace, name)
if err != nil {
return nil, err
}
if err := d.validateVMI(vmi); err != nil {
return nil, err
}
return vmi, nil
}
51 changes: 37 additions & 14 deletions pkg/virt-api/rest/streamer_test.go
Original file line number Diff line number Diff line change
@@ -46,25 +46,31 @@ var _ = Describe("Streamer", func() {
dialCalled bool
streamToClientCalled chan struct{}
streamToServerCalled chan struct{}
directDialer *DirectDialer
)
BeforeEach(func() {
testVMI = &v1.VirtualMachineInstance{ObjectMeta: metav1.ObjectMeta{Name: "test-vmi"}}
streamToClientCalled = make(chan struct{}, 1)
streamToServerCalled = make(chan struct{}, 1)
serverConn, serverPipe = net.Pipe()
streamer = &Streamer{
fetchVMI: func(_, _ string) (*v1.VirtualMachineInstance, *errors.StatusError) {
directDialer = NewDirectDialer(
func(_, _ string) (*v1.VirtualMachineInstance, *errors.StatusError) {
fetchVMICalled = true
return testVMI, nil
},
validateVMI: func(vmi *v1.VirtualMachineInstance) *errors.StatusError {
func(vmi *v1.VirtualMachineInstance) *errors.StatusError {
validateVMICalled = true
return nil
},
dial: func(vmi *v1.VirtualMachineInstance) (net.Conn, *errors.StatusError) {
dialCalled = true
return serverConn, nil
mockDialer{
dialUnderlying: func(vmi *v1.VirtualMachineInstance) (net.Conn, *errors.StatusError) {
dialCalled = true
return serverConn, nil
},
},
)
streamer = &Streamer{
dialer: directDialer,
streamToClient: func(clientSocket *websocket.Conn, serverConn net.Conn, result chan<- streamFuncResult) {
result <- nil
streamToClientCalled <- struct{}{}
@@ -102,14 +108,14 @@ var _ = Describe("Streamer", func() {
Expect(validateVMICalled).To(BeTrue())
})
It("validates the fetched VMI", func() {
streamer.validateVMI = func(vmi *v1.VirtualMachineInstance) *errors.StatusError {
directDialer.validateVMI = func(vmi *v1.VirtualMachineInstance) *errors.StatusError {
Expect(vmi).To(Equal(testVMI))
return nil
}
streamer.Handle(req, resp)
})
It("does not validate the VMI if it can't be fetched", func() {
streamer.fetchVMI = func(_, _ string) (*v1.VirtualMachineInstance, *errors.StatusError) {
directDialer.fetchVMI = func(_, _ string) (*v1.VirtualMachineInstance, *errors.StatusError) {
return nil, errors.NewInternalError(goerrors.New("test error"))
}

@@ -121,14 +127,16 @@ var _ = Describe("Streamer", func() {
Expect(dialCalled).To(BeTrue())
})
It("dials the fetched VMI", func() {
streamer.dial = func(vmi *v1.VirtualMachineInstance) (net.Conn, *errors.StatusError) {
Expect(vmi).To(Equal(testVMI))
return nil, nil
directDialer.dial = mockDialer{
dialUnderlying: func(vmi *v1.VirtualMachineInstance) (net.Conn, *errors.StatusError) {
Expect(vmi).To(Equal(testVMI))
return nil, nil
},
}
streamer.Handle(req, resp)
})
It("does not dial when VMI is invalid", func() {
streamer.validateVMI = func(_ *v1.VirtualMachineInstance) *errors.StatusError {
directDialer.validateVMI = func(_ *v1.VirtualMachineInstance) *errors.StatusError {
return errors.NewInternalError(goerrors.New("test error"))
}

@@ -145,8 +153,10 @@ var _ = Describe("Streamer", func() {
defer ws.Close()
})
It("does not attempt the client connection upgrade on a failed dial", func() {
streamer.dial = func(vmi *v1.VirtualMachineInstance) (net.Conn, *errors.StatusError) {
return nil, errors.NewInternalError(goerrors.New("test error"))
directDialer.dial = mockDialer{
dialUnderlying: func(vmi *v1.VirtualMachineInstance) (net.Conn, *errors.StatusError) {
return nil, errors.NewInternalError(goerrors.New("test error"))
},
}
srv, _, wsResp, err := testWebsocketDial(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
handleErr := streamer.Handle(restful.NewRequest(r), restful.NewResponse(rw))
@@ -407,3 +417,16 @@ func testWebsocketDial(handler http.HandlerFunc) (*httptest.Server, *websocket.C
ws, resp, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(srv.URL, "http"), nil)
return srv, ws, resp, err
}

type mockDialer struct {
dial func(vmi *v1.VirtualMachineInstance) (*websocket.Conn, *errors.StatusError)
dialUnderlying func(vmi *v1.VirtualMachineInstance) (net.Conn, *errors.StatusError)
}

func (m mockDialer) Dial(vmi *v1.VirtualMachineInstance) (*websocket.Conn, *errors.StatusError) {
return m.dial(vmi)
}

func (m mockDialer) DialUnderlying(vmi *v1.VirtualMachineInstance) (net.Conn, *errors.StatusError) {
return m.dialUnderlying(vmi)
}
7 changes: 7 additions & 0 deletions staging/src/kubevirt.io/client-go/kubecli/streamer.go
Original file line number Diff line number Diff line change
@@ -53,3 +53,10 @@ func (c *wsConn) SetDeadline(t time.Time) error {
}
return c.Conn.SetReadDeadline(t)
}

func NewWebsocketStreamer(conn *websocket.Conn, done chan struct{}) *wsStreamer {
return &wsStreamer{
conn: conn,
done: done,
}
}

0 comments on commit 74165c2

Please sign in to comment.