From bd3497cce9aa906c47d5b7216752f55beed3d9d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jeremy=20Lain=C3=A9?= Date: Sat, 22 Jun 2024 23:12:17 +0200 Subject: [PATCH] Rework packet encoding to support different protocol versions The wire representation of packet types changes in QUIC version 2. Instead of carrying around the wire representation of a packet type, we introduce a `QuicPacketType` enum. --- src/aioquic/asyncio/server.py | 4 +- src/aioquic/quic/connection.py | 106 ++++++++++++---------- src/aioquic/quic/logger.py | 22 ++--- src/aioquic/quic/packet.py | 138 +++++++++++++++++++---------- src/aioquic/quic/packet_builder.py | 48 +++++----- tests/test_connection.py | 10 +-- tests/test_packet.py | 45 ++++------ tests/test_packet_builder.py | 120 ++++++++++++------------- tests/test_recovery_cubic.py | 28 +++--- tests/test_recovery_reno.py | 8 +- 10 files changed, 288 insertions(+), 241 deletions(-) diff --git a/src/aioquic/asyncio/server.py b/src/aioquic/asyncio/server.py index b5f6cd8d7..90178c639 100644 --- a/src/aioquic/asyncio/server.py +++ b/src/aioquic/asyncio/server.py @@ -7,7 +7,7 @@ from ..quic.configuration import SMALLEST_MAX_DATAGRAM_SIZE, QuicConfiguration from ..quic.connection import NetworkAddress, QuicConnection from ..quic.packet import ( - PACKET_TYPE_INITIAL, + QuicPacketType, encode_quic_retry, encode_quic_version_negotiation, pull_quic_header, @@ -86,7 +86,7 @@ def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> N if ( protocol is None and len(data) >= SMALLEST_MAX_DATAGRAM_SIZE - and header.packet_type == PACKET_TYPE_INITIAL + and header.packet_type == QuicPacketType.INITIAL ): # retry if self._retry is not None: diff --git a/src/aioquic/quic/connection.py b/src/aioquic/quic/connection.py index f2b9182b8..3bf872419 100644 --- a/src/aioquic/quic/connection.py +++ b/src/aioquic/quic/connection.py @@ -34,23 +34,18 @@ from .packet import ( CONNECTION_ID_MAX_SIZE, NON_ACK_ELICITING_FRAME_TYPES, - PACKET_TYPE_HANDSHAKE, - PACKET_TYPE_INITIAL, - PACKET_TYPE_ONE_RTT, - PACKET_TYPE_RETRY, - PACKET_TYPE_ZERO_RTT, PROBING_FRAME_TYPES, RETRY_INTEGRITY_TAG_SIZE, STATELESS_RESET_TOKEN_SIZE, QuicErrorCode, QuicFrameType, + QuicPacketType, QuicProtocolVersion, QuicStreamFrame, QuicTransportParameters, get_retry_integrity_tag, get_spin_bit, is_draft_version, - is_long_header, pull_ack_frame, pull_quic_header, pull_quic_transport_parameters, @@ -122,12 +117,12 @@ def dump_cid(cid: bytes) -> str: return binascii.hexlify(cid).decode("ascii") -def get_epoch(packet_type: int) -> tls.Epoch: - if packet_type == PACKET_TYPE_INITIAL: +def get_epoch(packet_type: QuicPacketType) -> tls.Epoch: + if packet_type == QuicPacketType.INITIAL: return tls.Epoch.INITIAL - elif packet_type == PACKET_TYPE_ZERO_RTT: + elif packet_type == QuicPacketType.ZERO_RTT: return tls.Epoch.ZERO_RTT - elif packet_type == PACKET_TYPE_HANDSHAKE: + elif packet_type == QuicPacketType.HANDSHAKE: return tls.Epoch.HANDSHAKE else: return tls.Epoch.ONE_RTT @@ -544,10 +539,10 @@ def datagrams_to_send(self, now: float) -> List[Tuple[bytes, NetworkAddress]]: epoch_packet_types = [] if not self._handshake_confirmed: epoch_packet_types += [ - (tls.Epoch.INITIAL, PACKET_TYPE_INITIAL), - (tls.Epoch.HANDSHAKE, PACKET_TYPE_HANDSHAKE), + (tls.Epoch.INITIAL, QuicPacketType.INITIAL), + (tls.Epoch.HANDSHAKE, QuicPacketType.HANDSHAKE), ] - epoch_packet_types.append((tls.Epoch.ONE_RTT, PACKET_TYPE_ONE_RTT)) + epoch_packet_types.append((tls.Epoch.ONE_RTT, QuicPacketType.ONE_RTT)) for epoch, packet_type in epoch_packet_types: crypto = self._cryptos[epoch] if crypto.send.is_valid(): @@ -619,9 +614,9 @@ def datagrams_to_send(self, now: float) -> List[Tuple[bytes, NetworkAddress]]: packet.packet_type ), "scid": ( - dump_cid(self.host_cid) - if is_long_header(packet.packet_type) - else "" + "" + if packet.packet_type == QuicPacketType.ONE_RTT + else dump_cid(self.host_cid) ), "dcid": dump_cid(self._peer_cid.cid), }, @@ -791,7 +786,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non # contained in a datagram smaller than 1200 bytes. if ( not self._is_client - and header.packet_type == PACKET_TYPE_INITIAL + and header.packet_type == QuicPacketType.INITIAL and len(data) < SMALLEST_MAX_DATAGRAM_SIZE ): if self._quic_logger is not None: @@ -800,7 +795,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non event="packet_dropped", data={ "trigger": "initial_packet_datagram_too_small", - "raw": {"length": buf.capacity - start_off}, + "raw": {"length": header.packet_length}, }, ) return @@ -812,13 +807,16 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non destination_cid_seq = connection_id.sequence_number break if ( - self._is_client or header.packet_type == PACKET_TYPE_HANDSHAKE + self._is_client or header.packet_type == QuicPacketType.HANDSHAKE ) and destination_cid_seq is None: if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_dropped", - data={"trigger": "unknown_connection_id"}, + data={ + "trigger": "unknown_connection_id", + "raw": {"length": header.packet_length}, + }, ) return @@ -826,13 +824,10 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non if ( self._is_client and self._state == QuicConnectionState.FIRSTFLIGHT - and header.version == QuicProtocolVersion.NEGOTIATION + and header.packet_type == QuicPacketType.VERSION_NEGOTIATION and not self._version_negotiation_count ): # version negotiation - versions = [] - while not buf.eof(): - versions.append(buf.pull_uint32()) if self._quic_logger is not None: self._quic_logger.log_event( category="transport", @@ -840,20 +835,24 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non data={ "frames": [], "header": { - "packet_type": "version_negotiation", + "packet_type": self._quic_logger.packet_type( + header.packet_type + ), "scid": dump_cid(header.source_cid), "dcid": dump_cid(header.destination_cid), }, - "raw": {"length": buf.tell() - start_off}, + "raw": {"length": header.packet_length}, }, ) - if self._version in versions: + if self._version in header.supported_versions: self._logger.warning( "Version negotiation packet contains %s" % self._version ) return common = [ - x for x in self._configuration.supported_versions if x in versions + x + for x in self._configuration.supported_versions + if x in header.supported_versions ] chosen_version = common[0] if common else None if self._quic_logger is not None: @@ -861,7 +860,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non category="transport", event="version_information", data={ - "server_versions": versions, + "server_versions": header.supported_versions, "client_versions": self._configuration.supported_versions, "chosen_version": chosen_version, }, @@ -890,12 +889,15 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non self._quic_logger.log_event( category="transport", event="packet_dropped", - data={"trigger": "unsupported_version"}, + data={ + "trigger": "unsupported_version", + "raw": {"length": header.packet_length}, + }, ) return # handle retry packet - if header.packet_type == PACKET_TYPE_RETRY: + if header.packet_type == QuicPacketType.RETRY: if ( self._is_client and not self._retry_count @@ -920,7 +922,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non "scid": dump_cid(header.source_cid), "dcid": dump_cid(header.destination_cid), }, - "raw": {"length": buf.tell() - start_off}, + "raw": {"length": header.packet_length}, }, ) @@ -938,7 +940,10 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non self._quic_logger.log_event( category="transport", event="packet_dropped", - data={"trigger": "unexpected_packet"}, + data={ + "trigger": "unexpected_packet", + "raw": {"length": header.packet_length}, + }, ) return @@ -948,7 +953,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non # server initialization if not self._is_client and self._state == QuicConnectionState.FIRSTFLIGHT: assert ( - header.packet_type == PACKET_TYPE_INITIAL + header.packet_type == QuicPacketType.INITIAL ), "first packet must be INITIAL" crypto_frame_required = True self._network_paths = [network_path] @@ -965,7 +970,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non # decrypt packet encrypted_off = buf.tell() - start_off - end_off = buf.tell() + header.rest_length + end_off = start_off + header.packet_length buf.seek(end_off) try: @@ -978,7 +983,10 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non self._quic_logger.log_event( category="transport", event="packet_dropped", - data={"trigger": "key_unavailable"}, + data={ + "trigger": "key_unavailable", + "raw": {"length": header.packet_length}, + }, ) # If a client receives HANDSHAKE or 1-RTT packets before it has @@ -997,15 +1005,18 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non self._quic_logger.log_event( category="transport", event="packet_dropped", - data={"trigger": "payload_decrypt_error"}, + data={ + "trigger": "payload_decrypt_error", + "raw": {"length": header.packet_length}, + }, ) continue # check reserved bits - if header.is_long_header: - reserved_mask = 0x0C - else: + if header.packet_type == QuicPacketType.ONE_RTT: reserved_mask = 0x18 + else: + reserved_mask = 0x0C if plain_header[0] & reserved_mask: self.close( error_code=QuicErrorCode.PROTOCOL_VIOLATION, @@ -1031,7 +1042,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non "dcid": dump_cid(header.destination_cid), "scid": dump_cid(header.source_cid), }, - "raw": {"length": end_off - start_off}, + "raw": {"length": header.packet_length}, }, ) @@ -1053,7 +1064,10 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non self._set_state(QuicConnectionState.CONNECTED) # update spin bit - if not header.is_long_header and packet_number > self._spin_highest_pn: + if ( + header.packet_type == QuicPacketType.ONE_RTT + and packet_number > self._spin_highest_pn + ): spin_bit = get_spin_bit(plain_header[0]) if self._is_client: self._spin_bit = not spin_bit @@ -2802,10 +2816,10 @@ def _write_application( if self._cryptos[tls.Epoch.ONE_RTT].send.is_valid(): crypto = self._cryptos[tls.Epoch.ONE_RTT] crypto_stream = self._crypto_streams[tls.Epoch.ONE_RTT] - packet_type = PACKET_TYPE_ONE_RTT + packet_type = QuicPacketType.ONE_RTT elif self._cryptos[tls.Epoch.ZERO_RTT].send.is_valid(): crypto = self._cryptos[tls.Epoch.ZERO_RTT] - packet_type = PACKET_TYPE_ZERO_RTT + packet_type = QuicPacketType.ZERO_RTT else: return space = self._spaces[tls.Epoch.ONE_RTT] @@ -2977,9 +2991,9 @@ def _write_handshake( while True: if epoch == tls.Epoch.INITIAL: - packet_type = PACKET_TYPE_INITIAL + packet_type = QuicPacketType.INITIAL else: - packet_type = PACKET_TYPE_HANDSHAKE + packet_type = QuicPacketType.HANDSHAKE builder.start_packet(packet_type, crypto) # ACK diff --git a/src/aioquic/quic/logger.py b/src/aioquic/quic/logger.py index 8deede679..6cdbc961a 100644 --- a/src/aioquic/quic/logger.py +++ b/src/aioquic/quic/logger.py @@ -7,24 +7,20 @@ from ..h3.events import Headers from .packet import ( - PACKET_TYPE_HANDSHAKE, - PACKET_TYPE_INITIAL, - PACKET_TYPE_MASK, - PACKET_TYPE_ONE_RTT, - PACKET_TYPE_RETRY, - PACKET_TYPE_ZERO_RTT, QuicFrameType, + QuicPacketType, QuicStreamFrame, QuicTransportParameters, ) from .rangeset import RangeSet PACKET_TYPE_NAMES = { - PACKET_TYPE_INITIAL: "initial", - PACKET_TYPE_HANDSHAKE: "handshake", - PACKET_TYPE_ZERO_RTT: "0RTT", - PACKET_TYPE_ONE_RTT: "1RTT", - PACKET_TYPE_RETRY: "retry", + QuicPacketType.INITIAL: "initial", + QuicPacketType.HANDSHAKE: "handshake", + QuicPacketType.ZERO_RTT: "0RTT", + QuicPacketType.ONE_RTT: "1RTT", + QuicPacketType.RETRY: "retry", + QuicPacketType.VERSION_NEGOTIATION: "version_negotiation", } QLOG_VERSION = "0.3" @@ -212,8 +208,8 @@ def encode_transport_parameters( data[param_name] = param_value return data - def packet_type(self, packet_type: int) -> str: - return PACKET_TYPE_NAMES.get(packet_type & PACKET_TYPE_MASK, "1RTT") + def packet_type(self, packet_type: QuicPacketType) -> str: + return PACKET_TYPE_NAMES[packet_type] # HTTP/3 diff --git a/src/aioquic/quic/packet.py b/src/aioquic/quic/packet.py index 60c639b39..7d619ed04 100644 --- a/src/aioquic/quic/packet.py +++ b/src/aioquic/quic/packet.py @@ -2,7 +2,7 @@ import ipaddress import os from dataclasses import dataclass -from enum import IntEnum +from enum import Enum, IntEnum from typing import List, Optional, Tuple from cryptography.hazmat.primitives.ciphers.aead import AESGCM @@ -14,13 +14,6 @@ PACKET_FIXED_BIT = 0x40 PACKET_SPIN_BIT = 0x20 -PACKET_TYPE_INITIAL = PACKET_LONG_HEADER | PACKET_FIXED_BIT | 0x00 -PACKET_TYPE_ZERO_RTT = PACKET_LONG_HEADER | PACKET_FIXED_BIT | 0x10 -PACKET_TYPE_HANDSHAKE = PACKET_LONG_HEADER | PACKET_FIXED_BIT | 0x20 -PACKET_TYPE_RETRY = PACKET_LONG_HEADER | PACKET_FIXED_BIT | 0x30 -PACKET_TYPE_ONE_RTT = PACKET_FIXED_BIT -PACKET_TYPE_MASK = 0xF0 - CONNECTION_ID_MAX_SIZE = 20 PACKET_NUMBER_MAX_SIZE = 4 RETRY_AEAD_KEY_DRAFT_29 = binascii.unhexlify("ccce187ed09a09d05728155a6cb96be1") @@ -51,6 +44,31 @@ class QuicErrorCode(IntEnum): CRYPTO_ERROR = 0x100 +class QuicPacketType(Enum): + INITIAL = 0 + ZERO_RTT = 1 + HANDSHAKE = 2 + RETRY = 3 + VERSION_NEGOTIATION = 4 + ONE_RTT = 5 + + +# For backwards compatibility only, use `QuicPacketType` in new code. +PACKET_TYPE_INITIAL = QuicPacketType.INITIAL + +# QUIC version 1 +# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2 +PACKET_LONG_TYPE_ENCODE_VERSION_1 = { + QuicPacketType.INITIAL: 0, + QuicPacketType.ZERO_RTT: 1, + QuicPacketType.HANDSHAKE: 2, + QuicPacketType.RETRY: 3, +} +PACKET_LONG_TYPE_DECODE_VERSION_1 = dict( + (v, i) for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_1.items() +) + + class QuicProtocolVersion(IntEnum): NEGOTIATION = 0 VERSION_1 = 0x00000001 @@ -62,14 +80,29 @@ class QuicProtocolVersion(IntEnum): @dataclass class QuicHeader: - is_long_header: bool version: Optional[int] - packet_type: int + "The protocol version. Only present in long header packets." + + packet_type: QuicPacketType + "The type of the packet." + + packet_length: int + "The total length of the packet, in bytes." + destination_cid: bytes + "The destination connection ID." + source_cid: bytes - token: bytes = b"" - integrity_tag: bytes = b"" - rest_length: int = 0 + "The destination connection ID." + + token: bytes + "The address verification token. Only present in `INITIAL` and `RETRY` packets." + + integrity_tag: bytes + "The retry integrity tag. Only present in `RETRY` packets." + + supported_versions: List[int] + "Supported protocol versions. Only present in `VERSION_NEGOTIATION` packets." def decode_packet_number(truncated: int, num_bits: int, expected: int) -> int: @@ -134,12 +167,17 @@ def is_long_header(first_byte: int) -> bool: def pull_quic_header(buf: Buffer, host_cid_length: Optional[int] = None) -> QuicHeader: - first_byte = buf.pull_uint8() + packet_start = buf.tell() + version = None integrity_tag = b"" + supported_versions = [] token = b"" + + first_byte = buf.pull_uint8() if is_long_header(first_byte): - # long header packet + # Long Header Packets. + # https://datatracker.ietf.org/doc/html/rfc9000#section-17.2 version = buf.pull_uint32() destination_cid_length = buf.pull_uint8() @@ -155,56 +193,58 @@ def pull_quic_header(buf: Buffer, host_cid_length: Optional[int] = None) -> Quic source_cid = buf.pull_bytes(source_cid_length) if version == QuicProtocolVersion.NEGOTIATION: - # version negotiation - packet_type = None - rest_length = buf.capacity - buf.tell() + # Version Negotiation Packet. + # https://datatracker.ietf.org/doc/html/rfc9000#section-17.2.1 + packet_type = QuicPacketType.VERSION_NEGOTIATION + while not buf.eof(): + supported_versions.append(buf.pull_uint32()) + packet_end = buf.tell() else: if not (first_byte & PACKET_FIXED_BIT): raise ValueError("Packet fixed bit is zero") - packet_type = first_byte & PACKET_TYPE_MASK - if packet_type == PACKET_TYPE_INITIAL: + packet_type = PACKET_LONG_TYPE_DECODE_VERSION_1[(first_byte & 0x30) >> 4] + if packet_type == QuicPacketType.INITIAL: token_length = buf.pull_uint_var() token = buf.pull_bytes(token_length) rest_length = buf.pull_uint_var() - elif packet_type == PACKET_TYPE_RETRY: + elif packet_type == QuicPacketType.ZERO_RTT: + rest_length = buf.pull_uint_var() + elif packet_type == QuicPacketType.HANDSHAKE: + rest_length = buf.pull_uint_var() + else: token_length = buf.capacity - buf.tell() - RETRY_INTEGRITY_TAG_SIZE token = buf.pull_bytes(token_length) integrity_tag = buf.pull_bytes(RETRY_INTEGRITY_TAG_SIZE) rest_length = 0 - else: - rest_length = buf.pull_uint_var() - # check remainder length - if rest_length > buf.capacity - buf.tell(): + # Check remainder length. + packet_end = buf.tell() + rest_length + if packet_end > buf.capacity: raise ValueError("Packet payload is truncated") - return QuicHeader( - is_long_header=True, - version=version, - packet_type=packet_type, - destination_cid=destination_cid, - source_cid=source_cid, - token=token, - integrity_tag=integrity_tag, - rest_length=rest_length, - ) else: - # short header packet + # Short Header Packets. + # https://datatracker.ietf.org/doc/html/rfc9000#section-17.3 if not (first_byte & PACKET_FIXED_BIT): raise ValueError("Packet fixed bit is zero") - packet_type = first_byte & PACKET_TYPE_MASK + version = None + packet_type = QuicPacketType.ONE_RTT destination_cid = buf.pull_bytes(host_cid_length) - return QuicHeader( - is_long_header=False, - version=None, - packet_type=packet_type, - destination_cid=destination_cid, - source_cid=b"", - token=b"", - rest_length=buf.capacity - buf.tell(), - ) + source_cid = b"" + packet_end = buf.capacity + + return QuicHeader( + version=version, + packet_type=packet_type, + packet_length=packet_end - packet_start, + destination_cid=destination_cid, + source_cid=source_cid, + token=token, + integrity_tag=integrity_tag, + supported_versions=supported_versions, + ) def encode_quic_retry( @@ -221,7 +261,11 @@ def encode_quic_retry( + len(retry_token) + RETRY_INTEGRITY_TAG_SIZE ) - buf.push_uint8(PACKET_TYPE_RETRY) + buf.push_uint8( + PACKET_LONG_HEADER + | PACKET_FIXED_BIT + | PACKET_LONG_TYPE_ENCODE_VERSION_1[QuicPacketType.RETRY] << 4 + ) buf.push_uint32(version) buf.push_uint8(len(destination_cid)) buf.push_bytes(destination_cid) diff --git a/src/aioquic/quic/packet_builder.py b/src/aioquic/quic/packet_builder.py index 77f7ec582..cd2740346 100644 --- a/src/aioquic/quic/packet_builder.py +++ b/src/aioquic/quic/packet_builder.py @@ -9,12 +9,12 @@ from .packet import ( NON_ACK_ELICITING_FRAME_TYPES, NON_IN_FLIGHT_FRAME_TYPES, + PACKET_FIXED_BIT, + PACKET_LONG_HEADER, + PACKET_LONG_TYPE_ENCODE_VERSION_1, PACKET_NUMBER_MAX_SIZE, - PACKET_TYPE_HANDSHAKE, - PACKET_TYPE_INITIAL, - PACKET_TYPE_MASK, QuicFrameType, - is_long_header, + QuicPacketType, ) PACKET_LENGTH_SEND_SIZE = 2 @@ -36,7 +36,7 @@ class QuicSentPacket: is_ack_eliciting: bool is_crypto_packet: bool packet_number: int - packet_type: int + packet_type: QuicPacketType sent_time: Optional[float] = None sent_bytes: int = 0 @@ -92,10 +92,9 @@ def __init__( self._header_size = 0 self._packet: Optional[QuicSentPacket] = None self._packet_crypto: Optional[CryptoPair] = None - self._packet_long_header = False self._packet_number = packet_number self._packet_start = 0 - self._packet_type = 0 + self._packet_type: Optional[QuicPacketType] = None self._buffer = Buffer(max_datagram_size) self._buffer_capacity = max_datagram_size @@ -182,10 +181,16 @@ def start_frame( self._packet.delivery_handlers.append((handler, handler_args)) return self._buffer - def start_packet(self, packet_type: int, crypto: CryptoPair) -> None: + def start_packet(self, packet_type: QuicPacketType, crypto: CryptoPair) -> None: """ Starts a new packet. """ + assert packet_type in ( + QuicPacketType.INITIAL, + QuicPacketType.HANDSHAKE, + QuicPacketType.ZERO_RTT, + QuicPacketType.ONE_RTT, + ), "Invalid packet type" buf = self._buffer # finish previous datagram @@ -215,10 +220,9 @@ def start_packet(self, packet_type: int, crypto: CryptoPair) -> None: self._datagram_init = False # calculate header size - packet_long_header = is_long_header(packet_type) - if packet_long_header: + if packet_type != QuicPacketType.ONE_RTT: header_size = 11 + len(self._peer_cid) + len(self._host_cid) - if (packet_type & PACKET_TYPE_MASK) == PACKET_TYPE_INITIAL: + if packet_type == QuicPacketType.INITIAL: token_length = len(self._peer_token) header_size += size_uint_var(token_length) + token_length else: @@ -229,9 +233,9 @@ def start_packet(self, packet_type: int, crypto: CryptoPair) -> None: raise QuicPacketBuilderStop # determine ack epoch - if packet_type == PACKET_TYPE_INITIAL: + if packet_type == QuicPacketType.INITIAL: epoch = Epoch.INITIAL - elif packet_type == PACKET_TYPE_HANDSHAKE: + elif packet_type == QuicPacketType.HANDSHAKE: epoch = Epoch.HANDSHAKE else: epoch = Epoch.ONE_RTT @@ -246,7 +250,6 @@ def start_packet(self, packet_type: int, crypto: CryptoPair) -> None: packet_type=packet_type, ) self._packet_crypto = crypto - self._packet_long_header = packet_long_header self._packet_start = packet_start self._packet_type = packet_type self.quic_logger_frames = self._packet.quic_logger_frames @@ -272,7 +275,7 @@ def _end_packet(self) -> None: # 14.1. if ( (self._is_client or self._packet.is_ack_eliciting) - and self._packet_type == PACKET_TYPE_INITIAL + and self._packet_type == QuicPacketType.INITIAL and self.remaining_flight_space and self.remaining_flight_space > padding_size ): @@ -291,7 +294,7 @@ def _end_packet(self) -> None: ) # write header - if self._packet_long_header: + if self._packet_type != QuicPacketType.ONE_RTT: length = ( packet_size - self._header_size @@ -300,13 +303,18 @@ def _end_packet(self) -> None: ) buf.seek(self._packet_start) - buf.push_uint8(self._packet_type | (PACKET_NUMBER_SEND_SIZE - 1)) + buf.push_uint8( + PACKET_LONG_HEADER + | PACKET_FIXED_BIT + | PACKET_LONG_TYPE_ENCODE_VERSION_1[self._packet_type] << 4 + | (PACKET_NUMBER_SEND_SIZE - 1) + ) buf.push_uint32(self._version) buf.push_uint8(len(self._peer_cid)) buf.push_bytes(self._peer_cid) buf.push_uint8(len(self._host_cid)) buf.push_bytes(self._host_cid) - if (self._packet_type & PACKET_TYPE_MASK) == PACKET_TYPE_INITIAL: + if self._packet_type == QuicPacketType.INITIAL: buf.push_uint_var(len(self._peer_token)) buf.push_bytes(self._peer_token) buf.push_uint16(length | 0x4000) @@ -314,7 +322,7 @@ def _end_packet(self) -> None: else: buf.seek(self._packet_start) buf.push_uint8( - self._packet_type + PACKET_FIXED_BIT | (self._spin_bit << 5) | (self._packet_crypto.key_phase << 2) | (PACKET_NUMBER_SEND_SIZE - 1) @@ -338,7 +346,7 @@ def _end_packet(self) -> None: self._datagram_flight_bytes += self._packet.sent_bytes # short header packets cannot be coalesced, we need a new datagram - if not self._packet_long_header: + if self._packet_type == QuicPacketType.ONE_RTT: self._flush_current_datagram() self._packet_number += 1 diff --git a/tests/test_connection.py b/tests/test_connection.py index 97a1939c2..c031a9867 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -22,9 +22,9 @@ from aioquic.quic.crypto import CryptoPair from aioquic.quic.logger import QuicLogger from aioquic.quic.packet import ( - PACKET_TYPE_INITIAL, QuicErrorCode, QuicFrameType, + QuicPacketType, QuicProtocolVersion, QuicTransportParameters, encode_quic_retry, @@ -838,7 +838,7 @@ def test_initial_that_is_too_small(self): crypto.setup_initial( client._peer_cid.cid, is_client=False, version=client._version ) - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) buf = builder.start_frame(QuicFrameType.PADDING) buf.push_bytes(bytes(builder.remaining_flight_space)) @@ -1200,7 +1200,7 @@ def encrypt_packet(plain_header, plain_payload, packet_number): crypto.encrypt_packet = encrypt_packet - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) buf = builder.start_frame(QuicFrameType.PADDING) buf.push_bytes(bytes(builder.remaining_flight_space)) @@ -1230,7 +1230,7 @@ def test_receive_datagram_wrong_version(self): crypto.setup_initial( client._peer_cid.cid, is_client=False, version=client._version ) - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) buf = builder.start_frame(QuicFrameType.PADDING) buf.push_bytes(bytes(builder.remaining_flight_space)) @@ -3007,7 +3007,7 @@ def test_write_connection_close_early(self): ) crypto = CryptoPair() crypto.setup_initial(client.host_cid, is_client=True, version=client._version) - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) client._write_connection_close_frame( builder=builder, epoch=tls.Epoch.INITIAL, diff --git a/tests/test_packet.py b/tests/test_packet.py index 8f89d3e4b..6e2cd33e9 100644 --- a/tests/test_packet.py +++ b/tests/test_packet.py @@ -4,8 +4,7 @@ from aioquic.buffer import Buffer, BufferReadError from aioquic.quic import packet from aioquic.quic.packet import ( - PACKET_TYPE_INITIAL, - PACKET_TYPE_RETRY, + QuicPacketType, QuicPreferredAddress, QuicProtocolVersion, QuicTransportParameters, @@ -54,14 +53,13 @@ def test_pull_empty(self): def test_pull_initial_client(self): buf = Buffer(data=load("initial_client.bin")) header = pull_quic_header(buf, host_cid_length=8) - self.assertTrue(header.is_long_header) self.assertEqual(header.version, QuicProtocolVersion.VERSION_1) - self.assertEqual(header.packet_type, PACKET_TYPE_INITIAL) + self.assertEqual(header.packet_type, QuicPacketType.INITIAL) + self.assertEqual(header.packet_length, 1280) self.assertEqual(header.destination_cid, binascii.unhexlify("858b39368b8e3c6e")) self.assertEqual(header.source_cid, b"") self.assertEqual(header.token, b"") self.assertEqual(header.integrity_tag, b"") - self.assertEqual(header.rest_length, 1262) self.assertEqual(buf.tell(), 18) def test_pull_initial_client_truncated(self): @@ -73,25 +71,24 @@ def test_pull_initial_client_truncated(self): def test_pull_initial_server(self): buf = Buffer(data=load("initial_server.bin")) header = pull_quic_header(buf, host_cid_length=8) - self.assertTrue(header.is_long_header) self.assertEqual(header.version, QuicProtocolVersion.VERSION_1) - self.assertEqual(header.packet_type, PACKET_TYPE_INITIAL) + self.assertEqual(header.packet_type, QuicPacketType.INITIAL) + self.assertEqual(header.packet_length, 202) self.assertEqual(header.destination_cid, b"") self.assertEqual(header.source_cid, binascii.unhexlify("195c68344e28d479")) self.assertEqual(header.token, b"") self.assertEqual(header.integrity_tag, b"") - self.assertEqual(header.rest_length, 184) self.assertEqual(buf.tell(), 18) - def test_pull_retry(self): + def test_pull_retry_v1(self): original_destination_cid = binascii.unhexlify("fbbd219b7363b64b") data = load("retry.bin") buf = Buffer(data=data) header = pull_quic_header(buf, host_cid_length=8) - self.assertTrue(header.is_long_header) self.assertEqual(header.version, QuicProtocolVersion.VERSION_1) - self.assertEqual(header.packet_type, PACKET_TYPE_RETRY) + self.assertEqual(header.packet_type, QuicPacketType.RETRY) + self.assertEqual(header.packet_length, 125) self.assertEqual(header.destination_cid, binascii.unhexlify("e9d146d8d14cb28e")) self.assertEqual( header.source_cid, @@ -108,7 +105,6 @@ def test_pull_retry(self): self.assertEqual( header.integrity_tag, binascii.unhexlify("4620aafd42f1d630588b27575a12da5c") ) - self.assertEqual(header.rest_length, 0) self.assertEqual(buf.tell(), 125) # check integrity @@ -135,9 +131,9 @@ def test_pull_retry_draft_29(self): data = load("retry_draft_29.bin") buf = Buffer(data=data) header = pull_quic_header(buf, host_cid_length=8) - self.assertTrue(header.is_long_header) self.assertEqual(header.version, QuicProtocolVersion.DRAFT_29) - self.assertEqual(header.packet_type, PACKET_TYPE_RETRY) + self.assertEqual(header.packet_type, QuicPacketType.RETRY) + self.assertEqual(header.packet_length, 125) self.assertEqual(header.destination_cid, binascii.unhexlify("e9d146d8d14cb28e")) self.assertEqual( header.source_cid, @@ -154,7 +150,6 @@ def test_pull_retry_draft_29(self): self.assertEqual( header.integrity_tag, binascii.unhexlify("e65b170337b611270f10f4e633b6f51b") ) - self.assertEqual(header.rest_length, 0) self.assertEqual(buf.tell(), 125) # check integrity @@ -178,20 +173,17 @@ def test_pull_retry_draft_29(self): def test_pull_version_negotiation(self): buf = Buffer(data=load("version_negotiation.bin")) header = pull_quic_header(buf, host_cid_length=8) - self.assertTrue(header.is_long_header) self.assertEqual(header.version, QuicProtocolVersion.NEGOTIATION) - self.assertEqual(header.packet_type, None) + self.assertEqual(header.packet_type, QuicPacketType.VERSION_NEGOTIATION) + self.assertEqual(header.packet_length, 31) self.assertEqual(header.destination_cid, binascii.unhexlify("9aac5a49ba87a849")) self.assertEqual(header.source_cid, binascii.unhexlify("f92f4336fa951ba1")) self.assertEqual(header.token, b"") self.assertEqual(header.integrity_tag, b"") - self.assertEqual(header.rest_length, 8) - self.assertEqual(buf.tell(), 23) - - versions = [] - while not buf.eof(): - versions.append(buf.pull_uint32()) - self.assertEqual(versions, [0x45474716, QuicProtocolVersion.VERSION_1]) + self.assertEqual( + header.supported_versions, [0x45474716, QuicProtocolVersion.VERSION_1] + ) + self.assertEqual(buf.tell(), 31) def test_pull_long_header_dcid_too_long(self): buf = Buffer( @@ -229,14 +221,13 @@ def test_pull_long_header_too_short(self): def test_pull_short_header(self): buf = Buffer(data=load("short_header.bin")) header = pull_quic_header(buf, host_cid_length=8) - self.assertFalse(header.is_long_header) self.assertEqual(header.version, None) - self.assertEqual(header.packet_type, 0x50) + self.assertEqual(header.packet_type, QuicPacketType.ONE_RTT) + self.assertEqual(header.packet_length, 21) self.assertEqual(header.destination_cid, binascii.unhexlify("f45aa7b59c0e1ad6")) self.assertEqual(header.source_cid, b"") self.assertEqual(header.token, b"") self.assertEqual(header.integrity_tag, b"") - self.assertEqual(header.rest_length, 12) self.assertEqual(buf.tell(), 9) def test_pull_short_header_no_fixed_bit(self): diff --git a/tests/test_packet_builder.py b/tests/test_packet_builder.py index be1562d9f..16f6cc103 100644 --- a/tests/test_packet_builder.py +++ b/tests/test_packet_builder.py @@ -3,13 +3,7 @@ from aioquic.quic.configuration import SMALLEST_MAX_DATAGRAM_SIZE from aioquic.quic.crypto import CryptoPair -from aioquic.quic.packet import ( - PACKET_TYPE_HANDSHAKE, - PACKET_TYPE_INITIAL, - PACKET_TYPE_ONE_RTT, - QuicFrameType, - QuicProtocolVersion, -) +from aioquic.quic.packet import QuicFrameType, QuicPacketType, QuicProtocolVersion from aioquic.quic.packet_builder import ( QuicPacketBuilder, QuicPacketBuilderStop, @@ -48,7 +42,7 @@ def test_long_header_empty(self): builder = create_builder() crypto = create_crypto() - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1156) self.assertTrue(builder.packet_is_empty) @@ -65,14 +59,14 @@ def test_long_header_initial_client(self): crypto = create_crypto() # INITIAL, fully padded - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1156) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(100)) self.assertFalse(builder.packet_is_empty) # INITIAL, empty - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertTrue(builder.packet_is_empty) # check datagrams @@ -87,7 +81,7 @@ def test_long_header_initial_client(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_INITIAL, + packet_type=QuicPacketType.INITIAL, sent_bytes=1200, ) ], @@ -101,21 +95,21 @@ def test_long_header_initial_client_2(self): crypto = create_crypto() # INITIAL, full length - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1156) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) # INITIAL, full length - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1156) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(100)) self.assertFalse(builder.packet_is_empty) # INITIAL, empty - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertTrue(builder.packet_is_empty) # check datagrams @@ -130,7 +124,7 @@ def test_long_header_initial_client_2(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_INITIAL, + packet_type=QuicPacketType.INITIAL, sent_bytes=1200, ), QuicSentPacket( @@ -139,7 +133,7 @@ def test_long_header_initial_client_2(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=1, - packet_type=PACKET_TYPE_INITIAL, + packet_type=QuicPacketType.INITIAL, sent_bytes=1200, ), ], @@ -153,7 +147,7 @@ def test_long_header_initial_server(self): crypto = create_crypto() # INITIAL with ACK + CRYPTO + PADDING - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1156) buf = builder.start_frame(QuicFrameType.ACK) @@ -164,11 +158,11 @@ def test_long_header_initial_server(self): self.assertFalse(builder.packet_is_empty) # INITIAL, empty - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertTrue(builder.packet_is_empty) # HANDSHAKE with CRYPTO - builder.start_packet(PACKET_TYPE_HANDSHAKE, crypto) + builder.start_packet(QuicPacketType.HANDSHAKE, crypto) self.assertEqual(builder.remaining_flight_space, 1157) buf = builder.start_frame(QuicFrameType.CRYPTO) @@ -176,7 +170,7 @@ def test_long_header_initial_server(self): self.assertFalse(builder.packet_is_empty) # HANDSHAKE with CRYPTO - builder.start_packet(PACKET_TYPE_HANDSHAKE, crypto) + builder.start_packet(QuicPacketType.HANDSHAKE, crypto) self.assertEqual(builder.remaining_flight_space, 1157) buf = builder.start_frame(QuicFrameType.CRYPTO) @@ -184,7 +178,7 @@ def test_long_header_initial_server(self): self.assertFalse(builder.packet_is_empty) # HANDSHAKE, empty - builder.start_packet(PACKET_TYPE_HANDSHAKE, crypto) + builder.start_packet(QuicPacketType.HANDSHAKE, crypto) self.assertTrue(builder.packet_is_empty) # check datagrams @@ -199,7 +193,7 @@ def test_long_header_initial_server(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_INITIAL, + packet_type=QuicPacketType.INITIAL, sent_bytes=1200, ), QuicSentPacket( @@ -208,7 +202,7 @@ def test_long_header_initial_server(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=1, - packet_type=PACKET_TYPE_HANDSHAKE, + packet_type=QuicPacketType.HANDSHAKE, sent_bytes=1200, ), QuicSentPacket( @@ -217,7 +211,7 @@ def test_long_header_initial_server(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=2, - packet_type=PACKET_TYPE_HANDSHAKE, + packet_type=QuicPacketType.HANDSHAKE, sent_bytes=844, ), ], @@ -231,18 +225,18 @@ def test_long_header_initial_server_without_handshake(self): crypto = create_crypto() # INITIAL - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1156) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(100)) self.assertFalse(builder.packet_is_empty) # INITIAL, empty - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertTrue(builder.packet_is_empty) # HANDSHAKE, empty - builder.start_packet(PACKET_TYPE_HANDSHAKE, crypto) + builder.start_packet(QuicPacketType.HANDSHAKE, crypto) self.assertTrue(builder.packet_is_empty) # check datagrams @@ -257,7 +251,7 @@ def test_long_header_initial_server_without_handshake(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_INITIAL, + packet_type=QuicPacketType.INITIAL, sent_bytes=1200, ) ], @@ -275,7 +269,7 @@ def test_long_header_ping_only(self): crypto = create_crypto() # HANDSHAKE, with only a PING frame - builder.start_packet(PACKET_TYPE_HANDSHAKE, crypto) + builder.start_packet(QuicPacketType.HANDSHAKE, crypto) builder.start_frame(QuicFrameType.PING) self.assertFalse(builder.packet_is_empty) @@ -292,7 +286,7 @@ def test_long_header_ping_only(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_HANDSHAKE, + packet_type=QuicPacketType.HANDSHAKE, sent_bytes=45, ) ], @@ -303,25 +297,25 @@ def test_long_header_then_short_header(self): crypto = create_crypto() # INITIAL, full length - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1156) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) # INITIAL, empty - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertTrue(builder.packet_is_empty) # ONE_RTT, full length - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 1173) buf = builder.start_frame(QuicFrameType.STREAM_BASE) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) # ONE_RTT, empty - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertTrue(builder.packet_is_empty) # check datagrams @@ -338,7 +332,7 @@ def test_long_header_then_short_header(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_INITIAL, + packet_type=QuicPacketType.INITIAL, sent_bytes=1200, ), QuicSentPacket( @@ -347,7 +341,7 @@ def test_long_header_then_short_header(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=1, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1200, ), ], @@ -361,14 +355,14 @@ def test_long_header_then_long_header(self): crypto = create_crypto() # INITIAL - builder.start_packet(PACKET_TYPE_INITIAL, crypto) + builder.start_packet(QuicPacketType.INITIAL, crypto) self.assertEqual(builder.remaining_flight_space, 1156) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(199)) self.assertFalse(builder.packet_is_empty) # HANDSHAKE - builder.start_packet(PACKET_TYPE_HANDSHAKE, crypto) + builder.start_packet(QuicPacketType.HANDSHAKE, crypto) self.assertEqual(builder.remaining_flight_space, 1157) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(299)) @@ -376,7 +370,7 @@ def test_long_header_then_long_header(self): self.assertEqual(builder.remaining_flight_space, 857) # ONE_RTT - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 830) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(299)) @@ -396,7 +390,7 @@ def test_long_header_then_long_header(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_INITIAL, + packet_type=QuicPacketType.INITIAL, sent_bytes=1200, ), QuicSentPacket( @@ -405,7 +399,7 @@ def test_long_header_then_long_header(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=1, - packet_type=PACKET_TYPE_HANDSHAKE, + packet_type=QuicPacketType.HANDSHAKE, sent_bytes=343, ), QuicSentPacket( @@ -414,7 +408,7 @@ def test_long_header_then_long_header(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=2, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=327, ), ], @@ -427,7 +421,7 @@ def test_short_header_empty(self): builder = create_builder() crypto = create_crypto() - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 1173) self.assertTrue(builder.packet_is_empty) @@ -444,7 +438,7 @@ def test_short_header_padding(self): crypto = create_crypto() # ONE_RTT, full length - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 1173) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) @@ -463,7 +457,7 @@ def test_short_header_padding(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1200, ) ], @@ -481,14 +475,14 @@ def test_short_header_max_flight_bytes(self): crypto = create_crypto() - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 973) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) with self.assertRaises(QuicPacketBuilderStop): - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) builder.start_frame(QuicFrameType.CRYPTO) # check datagrams @@ -504,7 +498,7 @@ def test_short_header_max_flight_bytes(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1000, ), ], @@ -525,7 +519,7 @@ def test_short_header_max_flight_bytes_zero(self): crypto = create_crypto() with self.assertRaises(QuicPacketBuilderStop): - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) builder.start_frame(QuicFrameType.CRYPTO) # check datagrams @@ -546,12 +540,12 @@ def test_short_header_max_flight_bytes_zero_ack(self): crypto = create_crypto() - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) buf = builder.start_frame(QuicFrameType.ACK) buf.push_bytes(bytes(64)) with self.assertRaises(QuicPacketBuilderStop): - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) builder.start_frame(QuicFrameType.CRYPTO) # check datagrams @@ -567,7 +561,7 @@ def test_short_header_max_flight_bytes_zero_ack(self): is_ack_eliciting=False, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=92, ), ], @@ -586,7 +580,7 @@ def test_short_header_max_total_bytes_1(self): crypto = create_crypto() with self.assertRaises(QuicPacketBuilderStop): - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) # check datagrams datagrams, packets = builder.flush() @@ -605,14 +599,14 @@ def test_short_header_max_total_bytes_2(self): crypto = create_crypto() - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 773) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) with self.assertRaises(QuicPacketBuilderStop): - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) # check datagrams datagrams, packets = builder.flush() @@ -627,7 +621,7 @@ def test_short_header_max_total_bytes_2(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=800, ) ], @@ -642,20 +636,20 @@ def test_short_header_max_total_bytes_3(self): crypto = create_crypto() - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 1173) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) self.assertEqual(builder.remaining_flight_space, 773) buf = builder.start_frame(QuicFrameType.CRYPTO) buf.push_bytes(bytes(builder.remaining_flight_space)) self.assertFalse(builder.packet_is_empty) with self.assertRaises(QuicPacketBuilderStop): - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) # check datagrams datagrams, packets = builder.flush() @@ -671,7 +665,7 @@ def test_short_header_max_total_bytes_3(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1200, ), QuicSentPacket( @@ -680,7 +674,7 @@ def test_short_header_max_total_bytes_3(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=1, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=800, ), ], @@ -698,7 +692,7 @@ def test_short_header_ping_only(self): crypto = create_crypto() # HANDSHAKE, with only a PING frame - builder.start_packet(PACKET_TYPE_ONE_RTT, crypto) + builder.start_packet(QuicPacketType.ONE_RTT, crypto) builder.start_frame(QuicFrameType.PING) self.assertFalse(builder.packet_is_empty) @@ -715,7 +709,7 @@ def test_short_header_ping_only(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=29, ) ], diff --git a/tests/test_recovery_cubic.py b/tests/test_recovery_cubic.py index c71c7c8a5..6f654c304 100644 --- a/tests/test_recovery_cubic.py +++ b/tests/test_recovery_cubic.py @@ -9,7 +9,7 @@ CubicCongestionControl, better_cube_root, ) -from aioquic.quic.packet import PACKET_TYPE_INITIAL, PACKET_TYPE_ONE_RTT +from aioquic.quic.packet import QuicPacketType from aioquic.quic.packet_builder import QuicSentPacket from aioquic.quic.rangeset import RangeSet from aioquic.quic.recovery import QuicPacketRecovery, QuicPacketSpace @@ -58,7 +58,7 @@ def test_on_ack_received_ack_eliciting(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=0.0, ) @@ -94,7 +94,7 @@ def test_on_ack_received_non_ack_eliciting(self): is_ack_eliciting=False, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=123.45, ) @@ -130,7 +130,7 @@ def test_on_packet_lost_crypto(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_INITIAL, + packet_type=QuicPacketType.INITIAL, sent_bytes=1280, sent_time=0.0, ) @@ -153,7 +153,7 @@ def test_packet_expired(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=0.0, ) @@ -246,7 +246,7 @@ def test_reset_idle(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=10.0, ) @@ -299,7 +299,7 @@ def test_reno_friendly_region(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=0.0, ) @@ -337,7 +337,7 @@ def test_convex_region(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=0.0, ) @@ -379,7 +379,7 @@ def test_concave_region(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=0.0, ) @@ -418,7 +418,7 @@ def test_increasing_rtt_exiting_slow_start(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=200.0, ) @@ -463,7 +463,7 @@ def test_packet_lost(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=200.0, ) @@ -474,7 +474,7 @@ def test_packet_lost(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=240.0, ) @@ -522,7 +522,7 @@ def test_lost_with_W_max(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=200.0, ) @@ -564,7 +564,7 @@ def test_cwnd_target(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=0.0, ) diff --git a/tests/test_recovery_reno.py b/tests/test_recovery_reno.py index 87822a74a..041651591 100644 --- a/tests/test_recovery_reno.py +++ b/tests/test_recovery_reno.py @@ -2,7 +2,7 @@ from unittest import TestCase from aioquic import tls -from aioquic.quic.packet import PACKET_TYPE_INITIAL, PACKET_TYPE_ONE_RTT +from aioquic.quic.packet import QuicPacketType from aioquic.quic.packet_builder import QuicSentPacket from aioquic.quic.rangeset import RangeSet from aioquic.quic.recovery import QuicPacketRecovery, QuicPacketSpace @@ -41,7 +41,7 @@ def test_on_ack_received_ack_eliciting(self): is_ack_eliciting=True, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=0.0, ) @@ -77,7 +77,7 @@ def test_on_ack_received_non_ack_eliciting(self): is_ack_eliciting=False, is_crypto_packet=False, packet_number=0, - packet_type=PACKET_TYPE_ONE_RTT, + packet_type=QuicPacketType.ONE_RTT, sent_bytes=1280, sent_time=123.45, ) @@ -113,7 +113,7 @@ def test_on_packet_lost_crypto(self): is_ack_eliciting=True, is_crypto_packet=True, packet_number=0, - packet_type=PACKET_TYPE_INITIAL, + packet_type=QuicPacketType.INITIAL, sent_bytes=1280, sent_time=0.0, )