Skip to content

Commit

Permalink
Merge pull request mdlayher#32 from mdlayher/mdl-contextid
Browse files Browse the repository at this point in the history
vsock: export ContextID for checking host context ID and vsock support
  • Loading branch information
mdlayher authored Mar 19, 2019
2 parents 4b4b46f + 33f548a commit ce2162f
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 12 deletions.
17 changes: 14 additions & 3 deletions ioctl_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,23 @@ type fs interface {
Ioctl(fd uintptr, request int, argp unsafe.Pointer) error
}

// localContextID retrieves the local context ID for this system, using the
// methods from fs. The context ID is stored in cid for later use.
// contextID retrieves the local context ID for this system.
func contextID() (uint32, error) {
// Fetch the context ID using a real filesystem.
var cid uint32
if err := sysContextID(sysFS{}, &cid); err != nil {
return 0, err
}

return cid, nil
}

// sysContextID retrieves the local context ID for this system, using the
// methods from fs. The context ID is stored in cid for later use.
//
// This method uses this signature to enable easier testing without unsafe
// usage of unsafe.Pointer.
func localContextID(fs fs, cid *uint32) error {
func sysContextID(fs fs, cid *uint32) error {
f, err := fs.Open(devVsock)
if err != nil {
return err
Expand Down
18 changes: 11 additions & 7 deletions ioctl_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func Test_localContextIDGuest(t *testing.T) {
ioctl: ioctl,
}

if err := localContextID(fs, &cid); err != nil {
if err := sysContextID(fs, &cid); err != nil {
t.Fatalf("failed to retrieve host's context ID: %v", err)
}

Expand All @@ -70,11 +70,13 @@ func Test_localContextIDGuestIntegration(t *testing.T) {
t.Skip("machine is not a guest, skipping")
}

var cid uint32
if err := localContextID(sysFS{}, &cid); err != nil {
cid, err := ContextID()
if err != nil {
t.Fatalf("failed to retrieve guest's context ID: %v", err)
}

t.Logf("guest context ID: %d", cid)

// Guests should always have a context ID of 3 or more, since
// 0-2 are invalid or reserved.
if cid < 3 {
Expand All @@ -87,11 +89,13 @@ func Test_localContextIDHostIntegration(t *testing.T) {
t.Skip("machine is not a hypervisor, skipping")
}

var cid uint32
if err := localContextID(sysFS{}, &cid); err != nil {
cid, err := ContextID()
if err != nil {
t.Fatalf("failed to retrieve host's context ID: %v", err)
}

t.Logf("host context ID: %d", cid)

if want, got := uint32(Host), cid; want != got {
t.Fatalf("unexpected host context ID:\n- want: %d\n- got: %d",
want, got)
Expand All @@ -106,8 +110,8 @@ func isHypervisor(t *testing.T) bool {
t.Skipf("device %q not available, kernel module not loaded?", devVsock)
}

var cid uint32
if err := localContextID(sysFS{}, &cid); err != nil {
cid, err := ContextID()
if err != nil {
if os.IsPermission(err) {
t.Skipf("permission denied, make sure user has access to %q", devVsock)
}
Expand Down
4 changes: 2 additions & 2 deletions listener_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ func (l *listener) Accept() (net.Conn, error) {

// listenStream is the entry point for ListenStream on Linux.
func listenStream(port uint32) (*Listener, error) {
var cid uint32
if err := localContextID(sysFS{}, &cid); err != nil {
cid, err := ContextID()
if err != nil {
return nil, err
}

Expand Down
10 changes: 10 additions & 0 deletions vsock.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,13 @@ func (a *Addr) String() string {
func (a *Addr) fileName() string {
return fmt.Sprintf("%s:%s", a.Network(), a.String())
}

// ContextID retrieves the local VM sockets context ID for this system.
// ContextID can be used to directly determine if a system is capable of using
// VM sockets.
//
// If the kernel module is unavailable, access to the kernel module is denied,
// or VM sockets are unsupported on this system, it returns an error.
func ContextID() (uint32, error) {
return contextID()
}
2 changes: 2 additions & 0 deletions vsock_others.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,5 @@ func (*connFD) Read(_ []byte) (int, error) { return 0, errUni
func (*connFD) Write(_ []byte) (int, error) { return 0, errUnimplemented }
func (*connFD) Close() error { return errUnimplemented }
func (*connFD) Shutdown(_ int) error { return errUnimplemented }

func contextID() (uint32, error) { return 0, errUnimplemented }
5 changes: 5 additions & 0 deletions vsock_others_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ import "testing"
func TestUnimplemented(t *testing.T) {
want := errUnimplemented

if _, got := ContextID(); want != got {
t.Fatalf("unexpected error from ContextID:\n- want: %v\n- got: %v",
want, got)
}

if _, got := listenStream(0); want != got {
t.Fatalf("unexpected error from listenStream:\n- want: %v\n- got: %v",
want, got)
Expand Down

0 comments on commit ce2162f

Please sign in to comment.