Skip to content

Commit

Permalink
Implement Broadcast support
Browse files Browse the repository at this point in the history
This change adds support for the SO_BROADCAST socket option in gVisor Netstack.
This support includes getsockopt()/setsockopt() functionality for both UDP and
TCP endpoints (the latter being a NOOP), dispatching broadcast messages up and
down the stack, and route finding/creation for broadcast packets. Finally, a
suite of tests have been implemented, exercising this functionality through the
Linux syscall API.

PiperOrigin-RevId: 234850781
Change-Id: If3e666666917d39f55083741c78314a06defb26c
  • Loading branch information
amanda-tait authored and shentubot committed Feb 20, 2019
1 parent 3e3a1ef commit ea070b9
Show file tree
Hide file tree
Showing 22 changed files with 698 additions and 11 deletions.
3 changes: 3 additions & 0 deletions pkg/dhcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) (cfg
}, nil); err != nil {
return Config{}, fmt.Errorf("dhcp: connect failed: %v", err)
}
if err := ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil {
return Config{}, fmt.Errorf("dhcp: setsockopt SO_BROADCAST: %v", err)
}

epin, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions pkg/dhcp/dhcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,9 @@ func TestTwoServers(t *testing.T) {
if err = ep.Bind(tcpip.FullAddress{Port: ServerPort}, nil); err != nil {
t.Fatalf("dhcp: server bind: %v", err)
}
if err = ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil {
t.Fatalf("dhcp: setsockopt: %v", err)
}

serverCtx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down
3 changes: 3 additions & 0 deletions pkg/dhcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ func newEPConnServer(ctx context.Context, stack *stack.Stack, addrs []tcpip.Addr
if err := ep.Bind(tcpip.FullAddress{Port: ServerPort}, nil); err != nil {
return nil, fmt.Errorf("dhcp: server bind: %v", err)
}
if err := ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil {
return nil, fmt.Errorf("dhcp: server setsockopt: %v", err)
}
c := newEPConn(ctx, wq, ep)
return NewServer(ctx, c, addrs, cfg)
}
Expand Down
21 changes: 21 additions & 0 deletions pkg/sentry/socket/epsocket/epsocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int,

// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType transport.SockType, name, outLen int) (interface{}, *syserr.Error) {
// TODO: Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_TYPE:
if outLen < sizeOfInt32 {
Expand Down Expand Up @@ -681,6 +682,18 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family

return int32(v), nil

case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}

var v tcpip.BroadcastOption
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}

return int32(v), nil

case linux.SO_KEEPALIVE:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
Expand Down Expand Up @@ -982,6 +995,14 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
v := usermem.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v)))

case linux.SO_BROADCAST:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
}

v := usermem.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BroadcastOption(v)))

case linux.SO_PASSCRED:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
Expand Down
2 changes: 2 additions & 0 deletions pkg/syserr/netstack.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ var (
ErrQueueSizeNotSupported = New(tcpip.ErrQueueSizeNotSupported.String(), linux.ENOTTY)
ErrNoSuchFile = New(tcpip.ErrNoSuchFile.String(), linux.ENOENT)
ErrInvalidOptionValue = New(tcpip.ErrInvalidOptionValue.String(), linux.EINVAL)
ErrBroadcastDisabled = New(tcpip.ErrBroadcastDisabled.String(), linux.EACCES)
)

var netstackErrorTranslations = map[*tcpip.Error]*Error{
Expand Down Expand Up @@ -80,6 +81,7 @@ var netstackErrorTranslations = map[*tcpip.Error]*Error{
tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable,
tcpip.ErrMessageTooLong: ErrMessageTooLong,
tcpip.ErrNoBufferSpace: ErrNoBufferSpace,
tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled,
}

// TranslateNetstackError converts an error from the tcpip package to a sentry
Expand Down
15 changes: 15 additions & 0 deletions pkg/tcpip/stack/nic.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,21 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr

src, dst := netProto.ParseAddresses(vv.First())

// If the packet is destined to the IPv4 Broadcast address, then make a
// route to each IPv4 network endpoint and let each endpoint handle the
// packet.
if dst == header.IPv4Broadcast {
for _, ref := range n.endpoints {
if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref)
r.RemoteLinkAddress = remote
ref.ep.HandlePacket(&r, vv)
ref.decRef()
}
}
return
}

if ref := n.getRef(protocol, dst); ref != nil {
r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref)
r.RemoteLinkAddress = remote
Expand Down
70 changes: 59 additions & 11 deletions pkg/tcpip/stack/transport_demuxer.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,22 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
ep.selectEndpoint(id).HandlePacket(r, id, vv)
// If this is a broadcast datagram, deliver the datagram to all endpoints
// managed by ep.
if id.LocalAddress == header.IPv4Broadcast {
for i, endpoint := range ep.endpointsArr {
// HandlePacket modifies vv, so each endpoint needs its own copy.
if i == len(ep.endpointsArr)-1 {
endpoint.HandlePacket(r, id, vv)
break
}
vvCopy := buffer.NewView(vv.Size())
copy(vvCopy, vv.ToView())
endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
}
} else {
ep.selectEndpoint(id).HandlePacket(r, id, vv)
}
}

// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
Expand Down Expand Up @@ -224,20 +239,47 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN
}
}

// deliverPacket attempts to deliver the given packet. Returns true if it found
// an endpoint, false otherwise.
var loopbackSubnet = func() tcpip.Subnet {
sn, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00")
if err != nil {
panic(err)
}
return sn
}()

// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if it
// found one or more endpoints, false otherwise.
func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
}

// If a sender bound to the Loopback interface sends a broadcast,
// that broadcast must not be delivered to the sender.
if loopbackSubnet.Contains(r.RemoteAddress) && r.LocalAddress == header.IPv4Broadcast && id.LocalPort == id.RemotePort {
return false
}

// If the packet is a broadcast, then find all matching transport endpoints.
// Otherwise, try to find a single matching transport endpoint.
destEps := make([]TransportEndpoint, 0, 1)
eps.mu.RLock()
ep := d.findEndpointLocked(eps, vv, id)

if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast {
for epID, endpoint := range eps.endpoints {
if epID.LocalPort == id.LocalPort {
destEps = append(destEps, endpoint)
}
}
} else if ep := d.findEndpointLocked(eps, vv, id); ep != nil {
destEps = append(destEps, ep)
}
eps.mu.RUnlock()

// Fail if we didn't find one.
if ep == nil {
// Fail if we didn't find at least one matching transport endpoint.
if len(destEps) == 0 {
// UDP packet could not be delivered to an unknown destination port.
if protocol == header.UDPProtocolNumber {
r.Stats().UDP.UnknownPortErrors.Increment()
Expand All @@ -246,7 +288,9 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
}

// Deliver the packet.
ep.HandlePacket(r, id, vv)
for _, ep := range destEps {
ep.HandlePacket(r, id, vv)
}

return true
}
Expand Down Expand Up @@ -277,27 +321,31 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber,

func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
// Try to find a match with the id as provided.
if ep := eps.endpoints[id]; ep != nil {
if ep, ok := eps.endpoints[id]; ok {
return ep
}

// Try to find a match with the id minus the local address.
nid := id

nid.LocalAddress = ""
if ep := eps.endpoints[nid]; ep != nil {
if ep, ok := eps.endpoints[nid]; ok {
return ep
}

// Try to find a match with the id minus the remote part.
nid.LocalAddress = id.LocalAddress
nid.RemoteAddress = ""
nid.RemotePort = 0
if ep := eps.endpoints[nid]; ep != nil {
if ep, ok := eps.endpoints[nid]; ok {
return ep
}

// Try to find a match with only the local port.
nid.LocalAddress = ""
return eps.endpoints[nid]
if ep, ok := eps.endpoints[nid]; ok {
return ep
}

return nil
}
11 changes: 11 additions & 0 deletions pkg/tcpip/tcpip.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ var (
ErrNetworkUnreachable = &Error{msg: "network is unreachable"}
ErrMessageTooLong = &Error{msg: "message too long"}
ErrNoBufferSpace = &Error{msg: "no buffer space available"}
ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"}
)

// Errors related to Subnet
Expand Down Expand Up @@ -502,6 +503,10 @@ type RemoveMembershipOption MembershipOption
// TCP out-of-band data is delivered along with the normal in-band data.
type OutOfBandInlineOption int

// BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether
// datagram sockets are allowed to send packets to a broadcast address.
type BroadcastOption int

// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination adddress in the row.
Expand All @@ -527,6 +532,12 @@ func (r *Route) Match(addr Address) bool {
return false
}

// Using header.Ipv4Broadcast would introduce an import cycle, so
// we'll use a literal instead.
if addr == "\xff\xff\xff\xff" {
return true
}

for i := 0; i < len(r.Destination); i++ {
if (addr[i] & r.Mask[i]) != r.Destination[i] {
return false
Expand Down
20 changes: 20 additions & 0 deletions pkg/tcpip/transport/tcp/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ type endpoint struct {
route stack.Route `state:"manual"`
v6only bool
isConnectNotified bool
// TCP should never broadcast but Linux nevertheless supports enabling/
// disabling SO_BROADCAST, albeit as a NOOP.
broadcast bool

// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
Expand Down Expand Up @@ -813,6 +816,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
return nil

case tcpip.BroadcastOption:
e.mu.Lock()
e.broadcast = v != 0
e.mu.Unlock()
return nil

default:
return nil
}
Expand Down Expand Up @@ -971,6 +980,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = 1
return nil

case *tcpip.BroadcastOption:
e.mu.Lock()
v := e.broadcast
e.mu.Unlock()

*o = 0
if v {
*o = 1
}
return nil

default:
return tcpip.ErrUnknownProtocolOption
}
Expand Down
1 change: 1 addition & 0 deletions pkg/tcpip/transport/tcp/endpoint_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ func loadError(s string) *tcpip.Error {
tcpip.ErrNetworkUnreachable,
tcpip.ErrMessageTooLong,
tcpip.ErrNoBufferSpace,
tcpip.ErrBroadcastDisabled,
}

messageToError = make(map[string]*tcpip.Error)
Expand Down
23 changes: 23 additions & 0 deletions pkg/tcpip/transport/udp/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ type endpoint struct {
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
reusePort bool
broadcast bool

// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
Expand Down Expand Up @@ -347,6 +348,10 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
nicid = e.bindNICID
}

if to.Addr == header.IPv4Broadcast && !e.broadcast {
return 0, nil, tcpip.ErrBroadcastDisabled
}

r, _, _, err := e.connectRoute(nicid, *to)
if err != nil {
return 0, nil, err
Expand Down Expand Up @@ -502,6 +507,13 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Lock()
e.reusePort = v != 0
e.mu.Unlock()

case tcpip.BroadcastOption:
e.mu.Lock()
e.broadcast = v != 0
e.mu.Unlock()

return nil
}
return nil
}
Expand Down Expand Up @@ -581,6 +593,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = 0
return nil

case *tcpip.BroadcastOption:
e.mu.RLock()
v := e.broadcast
e.mu.RUnlock()

*o = 0
if v {
*o = 1
}
return nil

default:
return tcpip.ErrUnknownProtocolOption
}
Expand Down
Loading

0 comments on commit ea070b9

Please sign in to comment.