Skip to content

Commit

Permalink
Merge pull request quic-go#1612 from lucas-clemente/stateless-reset-r…
Browse files Browse the repository at this point in the history
…eceiving

implement receiving of stateless resets
  • Loading branch information
marten-seemann authored Nov 20, 2018
2 parents 54c287c + 39e1e9a commit 9edd783
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 17 deletions.
3 changes: 3 additions & 0 deletions internal/protocol/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ const MaxPacketSizeIPv4 = 1252
// MaxPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets.
const MaxPacketSizeIPv6 = 1232

// MinStatelessResetSize is the minimum size of a stateless reset packet
const MinStatelessResetSize = 1 + 20 + 16

// NonForwardSecurePacketSizeReduction is the number of bytes a non forward-secure packet has to be smaller than a forward-secure packet
// This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames
const NonForwardSecurePacketSizeReduction = 50
Expand Down
69 changes: 52 additions & 17 deletions packet_handler_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package quic

import (
"bytes"
"errors"
"fmt"
"net"
"sync"
Expand All @@ -12,6 +13,11 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire"
)

type packetHandlerEntry struct {
handler packetHandler
resetToken *[16]byte
}

// The packetHandlerMap stores packetHandlers, identified by connection ID.
// It is used:
// * by the server to store sessions
Expand All @@ -22,9 +28,10 @@ type packetHandlerMap struct {
conn net.PacketConn
connIDLen int

handlers map[string] /* string(ConnectionID)*/ packetHandler
server unknownPacketHandler
closed bool
handlers map[string] /* string(ConnectionID)*/ packetHandlerEntry
resetTokens map[[16]byte] /* stateless reset token */ packetHandler
server unknownPacketHandler
closed bool

deleteRetiredSessionsAfter time.Duration

Expand All @@ -37,7 +44,8 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger
m := &packetHandlerMap{
conn: conn,
connIDLen: connIDLen,
handlers: make(map[string]packetHandler),
handlers: make(map[string]packetHandlerEntry),
resetTokens: make(map[[16]byte]packetHandler),
deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout,
logger: logger,
}
Expand All @@ -47,13 +55,29 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger

func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) {
h.mutex.Lock()
h.handlers[string(id)] = handler
h.handlers[string(id)] = packetHandlerEntry{handler: handler}
h.mutex.Unlock()
}

func (h *packetHandlerMap) AddWithResetToken(id protocol.ConnectionID, handler packetHandler, token [16]byte) {
h.mutex.Lock()
h.handlers[string(id)] = packetHandlerEntry{handler: handler, resetToken: &token}
h.resetTokens[token] = handler
h.mutex.Unlock()
}

func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
h.removeByConnectionIDAsString(string(id))
}

func (h *packetHandlerMap) removeByConnectionIDAsString(id string) {
h.mutex.Lock()
delete(h.handlers, string(id))
if handlerEntry, ok := h.handlers[id]; ok {
if token := handlerEntry.resetToken; token != nil {
delete(h.resetTokens, *token)
}
delete(h.handlers, id)
}
h.mutex.Unlock()
}

Expand All @@ -63,9 +87,7 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {

func (h *packetHandlerMap) retireByConnectionIDAsString(id string) {
time.AfterFunc(h.deleteRetiredSessionsAfter, func() {
h.mutex.Lock()
delete(h.handlers, id)
h.mutex.Unlock()
h.removeByConnectionIDAsString(id)
})
}

Expand All @@ -79,7 +101,8 @@ func (h *packetHandlerMap) CloseServer() {
h.mutex.Lock()
h.server = nil
var wg sync.WaitGroup
for id, handler := range h.handlers {
for id, handlerEntry := range h.handlers {
handler := handlerEntry.handler
if handler.GetPerspective() == protocol.PerspectiveServer {
wg.Add(1)
go func(id string, handler packetHandler) {
Expand All @@ -103,12 +126,12 @@ func (h *packetHandlerMap) close(e error) error {
h.closed = true

var wg sync.WaitGroup
for _, handler := range h.handlers {
for _, handlerEntry := range h.handlers {
wg.Add(1)
go func(handler packetHandler) {
handler.destroy(e)
go func(handlerEntry packetHandlerEntry) {
handlerEntry.handler.destroy(e)
wg.Done()
}(handler)
}(handlerEntry)
}

if h.server != nil {
Expand Down Expand Up @@ -149,25 +172,37 @@ func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
}

h.mutex.RLock()
handler, ok := h.handlers[string(iHdr.DestConnectionID)]
handlerEntry, handlerFound := h.handlers[string(iHdr.DestConnectionID)]
server := h.server
h.mutex.RUnlock()

var sentBy protocol.Perspective
var version protocol.VersionNumber
var handlePacket func(*receivedPacket)
if ok { // existing session
if handlerFound { // existing session
handler := handlerEntry.handler
sentBy = handler.GetPerspective().Opposite()
version = handler.GetVersion()
handlePacket = handler.handlePacket
} else { // no session found
// this might be a stateless reset
if !iHdr.IsLongHeader && len(data) >= protocol.MinStatelessResetSize {
var token [16]byte
copy(token[:], data[len(data)-16:])
if sess, ok := h.resetTokens[token]; ok {
h.mutex.RUnlock()
sess.destroy(errors.New("received a stateless reset"))
return nil
}
}
if server == nil { // no server set
h.mutex.RUnlock()
return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
}
handlePacket = server.handlePacket
sentBy = protocol.PerspectiveClient
version = iHdr.Version
}
h.mutex.RUnlock()

hdr, err := iHdr.Parse(r, sentBy, version)
if err != nil {
Expand Down
48 changes: 48 additions & 0 deletions packet_handler_map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,54 @@ var _ = Describe("Packet Handler Map", func() {
})
})

Context("stateless reset handling", func() {
It("handles packets for connections added with a reset token", func() {
packetHandler := NewMockPacketHandler(mockCtrl)
connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}
token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddWithResetToken(connID, packetHandler, token)
// first send a normal packet
handledPacket := make(chan struct{})
packetHandler.EXPECT().GetPerspective()
packetHandler.EXPECT().GetVersion()
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
Expect(p.header.DestConnectionID).To(Equal(connID))
close(handledPacket)
})
conn.dataToRead <- getPacket(connID)
Eventually(handledPacket).Should(BeClosed())
})

It("handles stateless resets", func() {
packetHandler := NewMockPacketHandler(mockCtrl)
connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}
token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddWithResetToken(connID, packetHandler, token)
packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...)
packet = append(packet, token[:]...)
destroyed := make(chan struct{})
packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) {
close(destroyed)
})
conn.dataToRead <- packet
Eventually(destroyed).Should(BeClosed())
})

It("deletes reset tokens when the session is retired", func() {
handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond)
connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42}
token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
handler.AddWithResetToken(connID, NewMockPacketHandler(mockCtrl), token)
handler.Retire(connID)
time.Sleep(scaleDuration(30 * time.Millisecond))
Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0xdeadbeef42"))
packet := append([]byte{0x40, 0xde, 0xca, 0xfb, 0xad, 0x99} /* short header packet */, make([]byte, 50)...)
packet = append(packet, token[:]...)
Expect(handler.handlePacket(nil, packet)).To(MatchError("received a packet with an unexpected connection ID 0xdecafbad99"))
Expect(handler.resetTokens).To(BeEmpty())
})
})

Context("running a server", func() {
It("adds a server", func() {
connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}
Expand Down

0 comments on commit 9edd783

Please sign in to comment.