diff --git a/packet_handler_map.go b/packet_handler_map.go index 0f77dd9de8c..114d2a18e03 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -173,6 +173,14 @@ func (h *packetHandlerMap) handlePacket( return fmt.Errorf("error parsing header: %s", err) } + if hdr.IsLongHeader { + if protocol.ByteCount(r.Len()) < hdr.Length { + return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) + } + data = data[:int(hdr.ParsedLen()+hdr.Length)] + // TODO(#1312): implement parsing of compound packets + } + p := &receivedPacket{ remoteAddr: addr, hdr: hdr, diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 31bc7513f51..8a0141d3d98 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -19,21 +19,25 @@ var _ = Describe("Packet Handler Map", func() { conn *mockPacketConn ) - getPacket := func(connID protocol.ConnectionID) []byte { + getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) []byte { buf := &bytes.Buffer{} Expect((&wire.ExtendedHeader{ Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, DestConnectionID: connID, - Length: 1, + Length: length, Version: protocol.VersionTLS, }, - PacketNumberLen: protocol.PacketNumberLen1, + PacketNumberLen: protocol.PacketNumberLen2, }).Write(buf, protocol.VersionWhatever)).To(Succeed()) return buf.Bytes() } + getPacket := func(connID protocol.ConnectionID) []byte { + return getPacketWithLength(connID, 1) + } + BeforeEach(func() { conn = newMockPacketConn() handler = newPacketHandlerMap(conn, 5, utils.DefaultLogger).(*packetHandlerMap) @@ -131,6 +135,24 @@ var _ = Describe("Packet Handler Map", func() { conn.Close() Eventually(done).Should(BeClosed()) }) + + It("errors on packets that are smaller than the length in the packet header", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + data := append(getPacketWithLength(connID, 1000), make([]byte, 500-2 /* for packet number length */)...) + err := handler.handlePacket(nil, nil, data) + Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) + }) + + It("cuts packets to the right length", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + data := append(getPacketWithLength(connID, 456), make([]byte, 1000)...) + packetHandler := NewMockPacketHandler(mockCtrl) + packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + Expect(p.data).To(HaveLen(456 + int(p.hdr.ParsedLen()))) + }) + handler.Add(connID, packetHandler) + Expect(handler.handlePacket(nil, nil, data)).To(Succeed()) + }) }) Context("stateless reset handling", func() { diff --git a/packet_unpacker.go b/packet_unpacker.go index b9eb8ea0154..cda22232466 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -39,14 +39,6 @@ func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber) func (u *packetUnpacker) Unpack(hdr *wire.Header, data []byte) (*unpackedPacket, error) { r := bytes.NewReader(data) - if hdr.IsLongHeader { - if protocol.ByteCount(r.Len()) < hdr.Length { - return nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length) - } - data = data[:int(hdr.ParsedLen()+hdr.Length)] - // TODO(#1312): implement parsing of compound packets - } - var encLevel protocol.EncryptionLevel switch hdr.Type { case protocol.PacketTypeInitial: diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 8aa4e446b9b..c9a085d08f7 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -75,49 +75,6 @@ var _ = Describe("Packet Unpacker", func() { Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial)) }) - It("errors on packets that are smaller than the length in the packet header", func() { - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - Length: 1000, - DestConnectionID: connID, - Version: version, - }, - PacketNumberLen: protocol.PacketNumberLen2, - } - hdr, hdrRaw := getHeader(extHdr) - data := append(hdrRaw, make([]byte, 500-2 /* for packet number length */)...) - _, err := unpacker.Unpack(hdr, data) - Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)")) - }) - - It("cuts packets to the right length", func() { - pnLen := protocol.PacketNumberLen2 - extHdr := &wire.ExtendedHeader{ - Header: wire.Header{ - IsLongHeader: true, - DestConnectionID: connID, - Type: protocol.PacketTypeHandshake, - Length: 456, - Version: protocol.VersionTLS, - }, - PacketNumberLen: pnLen, - } - payloadLen := 456 - int(pnLen) - hdr, hdrRaw := getHeader(extHdr) - data := append(hdrRaw, make([]byte, payloadLen)...) - opener := mocks.NewMockOpener(mockCtrl) - cs.EXPECT().GetOpener(protocol.EncryptionHandshake).Return(opener, nil) - opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) - opener.EXPECT().Open(gomock.Any(), gomock.Any(), extHdr.PacketNumber, hdrRaw).DoAndReturn(func(_, payload []byte, _ protocol.PacketNumber, _ []byte) ([]byte, error) { - Expect(payload).To(HaveLen(payloadLen)) - return []byte{0}, nil - }) - _, err := unpacker.Unpack(hdr, data) - Expect(err).ToNot(HaveOccurred()) - }) - It("returns the error when getting the sealer fails", func() { extHdr := &wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: connID},