diff --git a/internal/utils/atomic_bool.go b/internal/utils/atomic_bool.go new file mode 100644 index 00000000000..cf4642504e0 --- /dev/null +++ b/internal/utils/atomic_bool.go @@ -0,0 +1,22 @@ +package utils + +import "sync/atomic" + +// An AtomicBool is an atomic bool +type AtomicBool struct { + v int32 +} + +// Set sets the value +func (a *AtomicBool) Set(value bool) { + var n int32 + if value { + n = 1 + } + atomic.StoreInt32(&a.v, n) +} + +// Get gets the value +func (a *AtomicBool) Get() bool { + return atomic.LoadInt32(&a.v) != 0 +} diff --git a/internal/utils/atomic_bool_test.go b/internal/utils/atomic_bool_test.go new file mode 100644 index 00000000000..83a200c2781 --- /dev/null +++ b/internal/utils/atomic_bool_test.go @@ -0,0 +1,29 @@ +package utils + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Atomic Bool", func() { + var a *AtomicBool + + BeforeEach(func() { + a = &AtomicBool{} + }) + + It("has the right default value", func() { + Expect(a.Get()).To(BeFalse()) + }) + + It("sets the value to true", func() { + a.Set(true) + Expect(a.Get()).To(BeTrue()) + }) + + It("sets the value to false", func() { + a.Set(true) + a.Set(false) + Expect(a.Get()).To(BeFalse()) + }) +}) diff --git a/packet_handler_map.go b/packet_handler_map.go index fef9468ed9d..4f3b69507f3 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -56,10 +56,6 @@ func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { } func (h *packetHandlerMap) removeByConnectionIDAsString(id string) { - h.mutex.Lock() - h.handlers[id] = nil - h.mutex.Unlock() - time.AfterFunc(h.deleteClosedSessionsAfter, func() { h.mutex.Lock() delete(h.handlers, id) @@ -102,13 +98,11 @@ func (h *packetHandlerMap) close(e error) error { var wg sync.WaitGroup for _, handler := range h.handlers { - if handler != nil { - wg.Add(1) - go func(handler packetHandler) { - handler.destroy(e) - wg.Done() - }(handler) - } + wg.Add(1) + go func(handler packetHandler) { + handler.destroy(e) + wg.Done() + }(handler) } if h.server != nil { diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index c5f1b267a3c..bc18c561ec2 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -88,20 +88,23 @@ var _ = Describe("Packet Handler Map", func() { Expect(err.Error()).To(ContainSubstring("error parsing invariant header:")) }) - It("deletes nil session entries after a wait time", func() { + It("deletes closed session entries after a wait time", func() { handler.deleteClosedSessionsAfter = 10 * time.Millisecond connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} handler.Add(connID, NewMockPacketHandler(mockCtrl)) handler.Remove(connID) - Eventually(func() error { - return handler.handlePacket(nil, getPacket(connID)) - }).Should(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) + time.Sleep(30 * time.Millisecond) + Expect(handler.handlePacket(nil, getPacket(connID))).To(MatchError("received a packet with an unexpected connection ID 0x0102030405060708")) }) - It("ignores packets arriving late for closed sessions", func() { + It("passes packets arriving late for closed sessions to that session", func() { handler.deleteClosedSessionsAfter = time.Hour connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - handler.Add(connID, NewMockPacketHandler(mockCtrl)) + packetHandler := NewMockPacketHandler(mockCtrl) + packetHandler.EXPECT().GetVersion().Return(protocol.VersionWhatever) + packetHandler.EXPECT().GetPerspective().Return(protocol.PerspectiveClient) + packetHandler.EXPECT().handlePacket(gomock.Any()) + handler.Add(connID, packetHandler) handler.Remove(connID) err := handler.handlePacket(nil, getPacket(connID)) Expect(err).ToNot(HaveOccurred()) diff --git a/session.go b/session.go index a55daf8bec9..2304455c1dc 100644 --- a/session.go +++ b/session.go @@ -94,9 +94,13 @@ type session struct { receivedPackets chan *receivedPacket sendingScheduled chan struct{} - // closeChan is used to notify the run loop that it should terminate. - closeChan chan closeError + closeOnce sync.Once + closed utils.AtomicBool + // closeChan is used to notify the run loop that it should terminate + closeChan chan closeError + connectionClosePacket *packedPacket + packetsReceivedAfterClose int ctx context.Context ctxCancel context.CancelFunc @@ -418,6 +422,7 @@ runLoop: if err := s.handleCloseError(closeErr); err != nil { s.logger.Infof("Handling close error failed: %s", err) } + s.closed.Set(true) s.logger.Infof("Connection %s closed.", s.srcConnID) s.sessionRunner.removeConnectionID(s.srcConnID) s.cryptoStreamHandler.Close() @@ -596,6 +601,9 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve // handlePacket is called by the server with a new packet func (s *session) handlePacket(p *receivedPacket) { + if s.closed.Get() { + s.handlePacketAfterClosed(p) + } // Discard packets once the amount of queued packets is larger than // the channel size, protocol.MaxSessionUnprocessedPackets select { @@ -604,6 +612,24 @@ func (s *session) handlePacket(p *receivedPacket) { } } +func (s *session) handlePacketAfterClosed(p *receivedPacket) { + s.packetsReceivedAfterClose++ + if s.connectionClosePacket == nil { + return + } + // exponential backoff + // only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving + for n := s.packetsReceivedAfterClose; n > 1; n = n / 2 { + if n%2 != 0 { + return + } + } + s.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", s.packetsReceivedAfterClose) + if err := s.conn.Write(s.connectionClosePacket.raw); err != nil { + s.logger.Debugf("Error retransmitting CONNECTION_CLOSE: %s", err) + } +} + func (s *session) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { encLevelChanged, err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel) if err != nil { @@ -943,6 +969,7 @@ func (s *session) sendConnectionClose(quicErr *qerr.QuicError) error { if err != nil { return err } + s.connectionClosePacket = packet s.logPacket(packet) return s.conn.Write(packet.raw) } diff --git a/session_test.go b/session_test.go index 54b94c8df49..7ab2bf1b1f2 100644 --- a/session_test.go +++ b/session_test.go @@ -426,6 +426,24 @@ var _ = Describe("Session", func() { sess.Close() Eventually(returned).Should(BeClosed()) }) + + It("retransmits the CONNECTION_CLOSE packet if packets are arriving late", func() { + streamManager.EXPECT().CloseWithError(gomock.Any()) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + cryptoSetup.EXPECT().Close() + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{raw: []byte("foobar")}, nil) + sess.Close() + Expect(mconn.written).To(Receive(Equal([]byte("foobar")))) // receive the CONNECTION_CLOSE + Eventually(sess.Context().Done()).Should(BeClosed()) + for i := 1; i <= 20; i++ { + sess.handlePacket(&receivedPacket{}) + if i == 1 || i == 2 || i == 4 || i == 8 || i == 16 { + Expect(mconn.written).To(Receive(Equal([]byte("foobar")))) // receive the CONNECTION_CLOSE + } else { + Expect(mconn.written).To(HaveLen(0)) + } + } + }) }) Context("receiving packets", func() {