Skip to content

Commit

Permalink
Batch write packets after iptables checks
Browse files Browse the repository at this point in the history
After IPTables checks a batch of packets, we can write packets that are
not dropped or locally destined as a batch instead of individually.

This previously caused a bug since WritePacket* functions expect to take
ownership of passed PacketBuffer{List}. WritePackets assumed the list of
PacketBuffers will not be invalidated when calling WritePacket for each
PacketBuffer in the list, but this is not true. WritePacket may add the
passed PacketBuffer into a different list which would modify the
PacketBuffer in such a way that it no longer points to the next
PacketBuffer to write.

Example: Given a PB list of
    PB_a -> PB_b -> PB_c

WritePackets may be iterating over the list and calling WritePacket for
each PB. When WritePacket takes PB_a, it may add it to a new list which
would update pointers such that PB_a no longer points to PB_b.

Test: integration_test.TestIPTableWritePackets
PiperOrigin-RevId: 355969560
  • Loading branch information
ghananigans authored and gvisor-bot committed Feb 6, 2021
1 parent 120c8e3 commit 83b764d
Show file tree
Hide file tree
Showing 3 changed files with 358 additions and 70 deletions.
60 changes: 25 additions & 35 deletions pkg/tcpip/network/ipv4/ipv4.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,47 +441,37 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
stats.PacketsSent.IncrementBy(uint64(n))
if err != nil {
stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
}
return n, err
}
stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
for pkt := range dropped {
pkts.Remove(pkt)
}

// Slow path as we are dropping some packets in the batch degrade to
// emitting one packet at a time.
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if _, ok := dropped[pkt]; ok {
// The NAT-ed packets may now be destined for us.
locallyDelivered := 0
for pkt := range natPkts {
ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, header.IPv4(pkt.NetworkHeader().View()).DestinationAddress())
if ep == nil {
// The NAT-ed packet is still destined for some remote node.
continue
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader().View())
if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil {
// Since we rewrote the packet but it is being routed back to us, we
// can safely assume the checksum is valid.
ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */)
n++
continue
}
}
if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
stats.PacketsSent.IncrementBy(uint64(n))
stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
// Dropped packets aren't errors, so include them in
// the return value.
return n + len(dropped), err
}
n++

// Do not send the locally destined packet out the NIC.
pkts.Remove(pkt)

// Deliver the packet locally.
ep.(*endpoint).handleLocalPacket(pkt, true)
locallyDelivered++

}
stats.PacketsSent.IncrementBy(uint64(n))

// The rest of the packets can be delivered to the NIC as a batch.
pktsLen := pkts.Len()
written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
stats.PacketsSent.IncrementBy(uint64(written))
stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))

// Dropped packets aren't errors, so include them in the return value.
return n + len(dropped), nil
return locallyDelivered + written + len(dropped), err
}

// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
Expand Down
58 changes: 23 additions & 35 deletions pkg/tcpip/network/ipv6/ipv6.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,48 +742,36 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
stats.PacketsSent.IncrementBy(uint64(n))
if err != nil {
stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
}
return n, err
}
stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
for pkt := range dropped {
pkts.Remove(pkt)
}

// Slow path as we are dropping some packets in the batch degrade to
// emitting one packet at a time.
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if _, ok := dropped[pkt]; ok {
// The NAT-ed packets may now be destined for us.
locallyDelivered := 0
for pkt := range natPkts {
ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, header.IPv6(pkt.NetworkHeader().View()).DestinationAddress())
if ep == nil {
// The NAT-ed packet is still destined for some remote node.
continue
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv6(pkt.NetworkHeader().View())
if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil {
// Since we rewrote the packet but it is being routed back to us, we
// can safely assume the checksum is valid.
ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */)
n++
continue
}
}
if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
stats.PacketsSent.IncrementBy(uint64(n))
stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped)))
// Dropped packets aren't errors, so include them in
// the return value.
return n + len(dropped), err
}
n++

// Do not send the locally destined packet out the NIC.
pkts.Remove(pkt)

// Deliver the packet locally.
ep.(*endpoint).handleLocalPacket(pkt, true)
locallyDelivered++
}

stats.PacketsSent.IncrementBy(uint64(n))
// The rest of the packets can be delivered to the NIC as a batch.
pktsLen := pkts.Len()
written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
stats.PacketsSent.IncrementBy(uint64(written))
stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))

// Dropped packets aren't errors, so include them in the return value.
return n + len(dropped), nil
return locallyDelivered + written + len(dropped), err
}

// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
Expand Down
Loading

0 comments on commit 83b764d

Please sign in to comment.