diff --git a/ioctl_linux.go b/ioctl_linux.go index 10959f3..2e59120 100644 --- a/ioctl_linux.go +++ b/ioctl_linux.go @@ -20,12 +20,23 @@ type fs interface { Ioctl(fd uintptr, request int, argp unsafe.Pointer) error } -// localContextID retrieves the local context ID for this system, using the -// methods from fs. The context ID is stored in cid for later use. +// contextID retrieves the local context ID for this system. +func contextID() (uint32, error) { + // Fetch the context ID using a real filesystem. + var cid uint32 + if err := sysContextID(sysFS{}, &cid); err != nil { + return 0, err + } + + return cid, nil +} + +// sysContextID retrieves the local context ID for this system, using the +// methods from fs. The context ID is stored in cid for later use. // // This method uses this signature to enable easier testing without unsafe // usage of unsafe.Pointer. -func localContextID(fs fs, cid *uint32) error { +func sysContextID(fs fs, cid *uint32) error { f, err := fs.Open(devVsock) if err != nil { return err diff --git a/ioctl_linux_test.go b/ioctl_linux_test.go index 76486d2..3f43f0e 100644 --- a/ioctl_linux_test.go +++ b/ioctl_linux_test.go @@ -55,7 +55,7 @@ func Test_localContextIDGuest(t *testing.T) { ioctl: ioctl, } - if err := localContextID(fs, &cid); err != nil { + if err := sysContextID(fs, &cid); err != nil { t.Fatalf("failed to retrieve host's context ID: %v", err) } @@ -70,11 +70,13 @@ func Test_localContextIDGuestIntegration(t *testing.T) { t.Skip("machine is not a guest, skipping") } - var cid uint32 - if err := localContextID(sysFS{}, &cid); err != nil { + cid, err := ContextID() + if err != nil { t.Fatalf("failed to retrieve guest's context ID: %v", err) } + t.Logf("guest context ID: %d", cid) + // Guests should always have a context ID of 3 or more, since // 0-2 are invalid or reserved. if cid < 3 { @@ -87,11 +89,13 @@ func Test_localContextIDHostIntegration(t *testing.T) { t.Skip("machine is not a hypervisor, skipping") } - var cid uint32 - if err := localContextID(sysFS{}, &cid); err != nil { + cid, err := ContextID() + if err != nil { t.Fatalf("failed to retrieve host's context ID: %v", err) } + t.Logf("host context ID: %d", cid) + if want, got := uint32(Host), cid; want != got { t.Fatalf("unexpected host context ID:\n- want: %d\n- got: %d", want, got) @@ -106,8 +110,8 @@ func isHypervisor(t *testing.T) bool { t.Skipf("device %q not available, kernel module not loaded?", devVsock) } - var cid uint32 - if err := localContextID(sysFS{}, &cid); err != nil { + cid, err := ContextID() + if err != nil { if os.IsPermission(err) { t.Skipf("permission denied, make sure user has access to %q", devVsock) } diff --git a/listener_linux.go b/listener_linux.go index 1dd7c5a..ec928a6 100644 --- a/listener_linux.go +++ b/listener_linux.go @@ -42,8 +42,8 @@ func (l *listener) Accept() (net.Conn, error) { // listenStream is the entry point for ListenStream on Linux. func listenStream(port uint32) (*Listener, error) { - var cid uint32 - if err := localContextID(sysFS{}, &cid); err != nil { + cid, err := ContextID() + if err != nil { return nil, err } diff --git a/vsock.go b/vsock.go index 8f7f167..516917c 100644 --- a/vsock.go +++ b/vsock.go @@ -170,3 +170,13 @@ func (a *Addr) String() string { func (a *Addr) fileName() string { return fmt.Sprintf("%s:%s", a.Network(), a.String()) } + +// ContextID retrieves the local VM sockets context ID for this system. +// ContextID can be used to directly determine if a system is capable of using +// VM sockets. +// +// If the kernel module is unavailable, access to the kernel module is denied, +// or VM sockets are unsupported on this system, it returns an error. +func ContextID() (uint32, error) { + return contextID() +} diff --git a/vsock_others.go b/vsock_others.go index 05af265..882191a 100644 --- a/vsock_others.go +++ b/vsock_others.go @@ -36,3 +36,5 @@ func (*connFD) Read(_ []byte) (int, error) { return 0, errUni func (*connFD) Write(_ []byte) (int, error) { return 0, errUnimplemented } func (*connFD) Close() error { return errUnimplemented } func (*connFD) Shutdown(_ int) error { return errUnimplemented } + +func contextID() (uint32, error) { return 0, errUnimplemented } diff --git a/vsock_others_test.go b/vsock_others_test.go index b1afeb0..d872109 100644 --- a/vsock_others_test.go +++ b/vsock_others_test.go @@ -7,6 +7,11 @@ import "testing" func TestUnimplemented(t *testing.T) { want := errUnimplemented + if _, got := ContextID(); want != got { + t.Fatalf("unexpected error from ContextID:\n- want: %v\n- got: %v", + want, got) + } + if _, got := listenStream(0); want != got { t.Fatalf("unexpected error from listenStream:\n- want: %v\n- got: %v", want, got)