forked from mdlayher/vsock
-
Notifications
You must be signed in to change notification settings - Fork 0
/
conn_linux_test.go
152 lines (127 loc) · 3.03 KB
/
conn_linux_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
//+build linux
package vsock
import (
"errors"
"syscall"
"testing"
"github.com/google/go-cmp/cmp"
"golang.org/x/sys/unix"
)
// Test_dialLinuxErrorClosesFile verifies that when connect fails, dialLinux
// returns an error and also closes the underlying file descriptor.
func Test_dialLinuxErrorClosesFile(t *testing.T) {
	// closed flips to true once dialLinux invokes the fake's Close.
	closed := false

	fd := &testConnFD{
		// Every connect attempt fails.
		connect: func(unix.Sockaddr) error {
			return errors.New("error during connect")
		},
		// Record that the descriptor was released.
		close: func() error {
			closed = true
			return nil
		},
	}

	_, dialErr := dialLinux(fd, 0, 0)
	if dialErr == nil {
		t.Fatal("expected an error, but none occurred")
	}

	diff := cmp.Diff(true, closed)
	if diff != "" {
		t.Fatalf("unexpected closed value (-want +got):\n%s", diff)
	}
}
// Test_dialLinuxFull drives dialLinux against a fake file descriptor and
// checks the following:
//   - it connects to the requested remote CID/port;
//   - it passes the expected name to setNonblocking;
//   - the returned connection reports the local address from getsockname
//     and the dialed remote address;
//   - SyscallConn, Close, CloseRead and CloseWrite reach the fake.
func Test_dialLinuxFull(t *testing.T) {
	const (
		localCID   uint32 = 3
		localPort  uint32 = 1024
		remoteCID  uint32 = Host
		remotePort uint32 = 2048
	)

	lsa := &unix.SockaddrVM{
		CID:  localCID,
		Port: localPort,
	}

	rsa := &unix.SockaddrVM{
		CID:  remoteCID,
		Port: remotePort,
	}

	var (
		closed      bool
		closedRead  bool
		closedWrite bool
		syscallConn bool
	)

	cfd := &testConnFD{
		connect: func(sa unix.Sockaddr) error {
			// Use a checked assertion so an unexpected sockaddr type is
			// reported as a test failure rather than a panic.
			vsa, ok := sa.(*unix.SockaddrVM)
			if !ok {
				t.Fatalf("unexpected connect sockaddr type: %T", sa)
			}

			if diff := cmp.Diff(rsa, vsa, cmp.AllowUnexported(*rsa)); diff != "" {
				t.Fatalf("unexpected connect sockaddr (-want +got):\n%s", diff)
			}

			return nil
		},
		getsockname: func() (unix.Sockaddr, error) {
			return lsa, nil
		},
		setNonblocking: func(name string) error {
			// cmp.Diff takes (want, got) so the output matches the
			// "(-want +got)" legend in the failure message.
			if diff := cmp.Diff("vsock:vm(3):1024", name); diff != "" {
				t.Fatalf("unexpected non-blocking file name (-want +got):\n%s", diff)
			}

			return nil
		},
		close: func() error {
			closed = true
			return nil
		},
		shutdown: func(how int) error {
			switch how {
			case unix.SHUT_RD:
				closedRead = true
			case unix.SHUT_WR:
				closedWrite = true
			default:
				t.Fatalf("unexpected how constant in shutdown: %d", how)
			}

			return nil
		},
		syscallConn: func() (syscall.RawConn, error) {
			// No need to really do anything.
			syscallConn = true
			return nil, nil
		},
	}

	c, err := dialLinux(cfd, remoteCID, remotePort)
	if err != nil {
		t.Fatalf("failed to dial: %v", err)
	}

	localAddr := &Addr{
		ContextID: localCID,
		Port:      localPort,
	}

	if diff := cmp.Diff(localAddr, c.LocalAddr()); diff != "" {
		t.Fatalf("unexpected local address (-want +got):\n%s", diff)
	}

	remoteAddr := &Addr{
		ContextID: remoteCID,
		Port:      remotePort,
	}

	if diff := cmp.Diff(remoteAddr, c.RemoteAddr()); diff != "" {
		t.Fatalf("unexpected remote address (-want +got):\n%s", diff)
	}

	if _, err := c.SyscallConn(); err != nil {
		t.Fatalf("failed to test syscall conn: %v", err)
	}

	if !syscallConn {
		t.Fatal("expected call to SyscallConn, but none occurred")
	}

	// Verify Close/Shutdown plumbing.
	funcs := []func() error{
		c.Close,
		c.CloseRead,
		c.CloseWrite,
	}

	for i, fn := range funcs {
		if err := fn(); err != nil {
			t.Fatalf("failed to invoke function %d: %v", i, err)
		}
	}

	if !closed || !closedRead || !closedWrite {
		t.Fatalf("expected calls to Close (%t), CloseRead (%t), and CloseWrite (%t)",
			closed, closedRead, closedWrite)
	}
}