-
Notifications
You must be signed in to change notification settings - Fork 13
/
socket.go
110 lines (94 loc) · 2.69 KB
/
socket.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
package internal
import (
"fmt"
"net"
"os"
"code.cfops.it/sys/tubular/internal/sysconn"
"golang.org/x/sys/unix"
)
// WriteToSocket sends p over the Unix socket conn and, when file is non-nil,
// passes file's descriptor along with it as SCM_RIGHTS ancillary data.
//
// A nil file is allowed; then only p is written. Each call to this function
// performs exactly one write on conn. If the kernel accepts fewer ancillary
// bytes than were encoded, an error is returned together with the number of
// bytes of p that were written.
func WriteToSocket(conn *net.UnixConn, p []byte, file *os.File) (int, error) {
	if file != nil {
		// NOTE(review): sysconn.ControlInt appears to expose file's raw fd to
		// the callback and forward the callback's result — confirm in sysconn.
		return sysconn.ControlInt(file, func(fd int) (int, error) {
			rights := unix.UnixRights(fd)
			written, rightsWritten, err := conn.WriteMsgUnix(p, rights, nil)
			switch {
			case err != nil:
				return written, err
			case rightsWritten != len(rights):
				return written, fmt.Errorf("short write of out-of-band data")
			default:
				return written, nil
			}
		})
	}

	// Nothing to pass out of band, so a plain write is sufficient.
	return conn.Write(p)
}
// ReadFromSocket reads a message, between zero and one file descriptors and
// the sender's uid from a Unix socket.
//
// file is optional and may be nil. The function requires SO_PASSCRED to be
// set on conn, so uid is always valid if no error is returned.
//
// Each call to this function performs exactly one read on conn. When an error
// is returned, n and uid are zero, file is nil, and every file descriptor that
// arrived with the message has been closed. If several problems are detected
// in the control messages, the first one is reported.
func ReadFromSocket(conn *net.UnixConn, p []byte) (n, uid int, file *os.File, err error) {
	const sizeofInt32 = 4
	// Room for exactly one SCM_RIGHTS message carrying a single fd plus one
	// SCM_CREDENTIALS message. Anything larger is cut off by the kernel and
	// flagged with MSG_CTRUNC.
	rightsLen := unix.CmsgSpace(1 * sizeofInt32)
	credsLen := unix.CmsgSpace(unix.SizeofUcred)
	oob := make([]byte, rightsLen+credsLen)

	n, oobn, flags, _, err := conn.ReadMsgUnix(p, oob)
	if err != nil {
		return 0, 0, nil, err
	}

	scms, err := unix.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return 0, 0, nil, fmt.Errorf("parse control messages: %w", err)
	}

	// firstErr holds the first problem found. We deliberately keep going after
	// an error so that every received fd is either kept (and closed below) or
	// closed immediately. Later, successfully parsed messages must never clear it.
	var firstErr error
	fail := func(e error) {
		if firstErr == nil {
			firstErr = e
		}
	}

	if flags&unix.MSG_CTRUNC != 0 {
		// The kernel discarded control data that did not fit into oob, e.g.
		// additional fds. The sender passed more than we are willing to accept.
		fail(fmt.Errorf("control messages truncated"))
	}

	var creds *unix.Ucred
	for _, scm := range scms {
		if scm.Header.Level != unix.SOL_SOCKET {
			fail(fmt.Errorf("unrecognised cmsg level: %d", scm.Header.Level))
			continue
		}

		switch scm.Header.Type {
		case unix.SCM_CREDENTIALS:
			c, perr := unix.ParseUnixCredentials(&scm)
			if perr != nil {
				fail(fmt.Errorf("parse credentials: %w", perr))
				continue
			}
			creds = c

		case unix.SCM_RIGHTS:
			rights, perr := unix.ParseUnixRights(&scm)
			if perr != nil {
				fail(fmt.Errorf("parse rights: %w", perr))
				continue
			}
			if len(rights) == 0 {
				// Nothing to take ownership of; guard against indexing below.
				continue
			}
			if len(rights) > 1 || file != nil {
				for _, fd := range rights {
					// Don't let the remote end flood us with fds
					unix.Close(fd)
				}
				fail(fmt.Errorf("can't handle more than one file descriptor"))
				continue
			}
			file = os.NewFile(uintptr(rights[0]), "cmsg fd")

		default:
			fail(fmt.Errorf("unrecognised cmsg type: %d", scm.Header.Type))
		}
	}

	if firstErr == nil && creds == nil {
		firstErr = fmt.Errorf("missing credentials")
	}
	if firstErr != nil {
		if file != nil {
			file.Close()
		}
		return 0, 0, nil, firstErr
	}
	return n, int(creds.Uid), file, nil
}