Skip to content

Commit

Permalink
Rework packet encoding to support different protocol versions
Browse files Browse the repository at this point in the history
The wire representation of packet types changes in QUIC version 2.
Instead of carrying around the wire representation of a packet type, we
introduce a `QuicPacketType` enum.
  • Loading branch information
jlaine committed Jun 23, 2024
1 parent ff3281f commit bd3497c
Show file tree
Hide file tree
Showing 10 changed files with 288 additions and 241 deletions.
4 changes: 2 additions & 2 deletions src/aioquic/asyncio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ..quic.configuration import SMALLEST_MAX_DATAGRAM_SIZE, QuicConfiguration
from ..quic.connection import NetworkAddress, QuicConnection
from ..quic.packet import (
PACKET_TYPE_INITIAL,
QuicPacketType,
encode_quic_retry,
encode_quic_version_negotiation,
pull_quic_header,
Expand Down Expand Up @@ -86,7 +86,7 @@ def datagram_received(self, data: Union[bytes, Text], addr: NetworkAddress) -> N
if (
protocol is None
and len(data) >= SMALLEST_MAX_DATAGRAM_SIZE
and header.packet_type == PACKET_TYPE_INITIAL
and header.packet_type == QuicPacketType.INITIAL
):
# retry
if self._retry is not None:
Expand Down
106 changes: 60 additions & 46 deletions src/aioquic/quic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,18 @@
from .packet import (
CONNECTION_ID_MAX_SIZE,
NON_ACK_ELICITING_FRAME_TYPES,
PACKET_TYPE_HANDSHAKE,
PACKET_TYPE_INITIAL,
PACKET_TYPE_ONE_RTT,
PACKET_TYPE_RETRY,
PACKET_TYPE_ZERO_RTT,
PROBING_FRAME_TYPES,
RETRY_INTEGRITY_TAG_SIZE,
STATELESS_RESET_TOKEN_SIZE,
QuicErrorCode,
QuicFrameType,
QuicPacketType,
QuicProtocolVersion,
QuicStreamFrame,
QuicTransportParameters,
get_retry_integrity_tag,
get_spin_bit,
is_draft_version,
is_long_header,
pull_ack_frame,
pull_quic_header,
pull_quic_transport_parameters,
Expand Down Expand Up @@ -122,12 +117,12 @@ def dump_cid(cid: bytes) -> str:
return binascii.hexlify(cid).decode("ascii")


def get_epoch(packet_type: int) -> tls.Epoch:
if packet_type == PACKET_TYPE_INITIAL:
def get_epoch(packet_type: QuicPacketType) -> tls.Epoch:
if packet_type == QuicPacketType.INITIAL:
return tls.Epoch.INITIAL
elif packet_type == PACKET_TYPE_ZERO_RTT:
elif packet_type == QuicPacketType.ZERO_RTT:
return tls.Epoch.ZERO_RTT
elif packet_type == PACKET_TYPE_HANDSHAKE:
elif packet_type == QuicPacketType.HANDSHAKE:
return tls.Epoch.HANDSHAKE
else:
return tls.Epoch.ONE_RTT
Expand Down Expand Up @@ -544,10 +539,10 @@ def datagrams_to_send(self, now: float) -> List[Tuple[bytes, NetworkAddress]]:
epoch_packet_types = []
if not self._handshake_confirmed:
epoch_packet_types += [
(tls.Epoch.INITIAL, PACKET_TYPE_INITIAL),
(tls.Epoch.HANDSHAKE, PACKET_TYPE_HANDSHAKE),
(tls.Epoch.INITIAL, QuicPacketType.INITIAL),
(tls.Epoch.HANDSHAKE, QuicPacketType.HANDSHAKE),
]
epoch_packet_types.append((tls.Epoch.ONE_RTT, PACKET_TYPE_ONE_RTT))
epoch_packet_types.append((tls.Epoch.ONE_RTT, QuicPacketType.ONE_RTT))
for epoch, packet_type in epoch_packet_types:
crypto = self._cryptos[epoch]
if crypto.send.is_valid():
Expand Down Expand Up @@ -619,9 +614,9 @@ def datagrams_to_send(self, now: float) -> List[Tuple[bytes, NetworkAddress]]:
packet.packet_type
),
"scid": (
dump_cid(self.host_cid)
if is_long_header(packet.packet_type)
else ""
""
if packet.packet_type == QuicPacketType.ONE_RTT
else dump_cid(self.host_cid)
),
"dcid": dump_cid(self._peer_cid.cid),
},
Expand Down Expand Up @@ -791,7 +786,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
# contained in a datagram smaller than 1200 bytes.
if (
not self._is_client
and header.packet_type == PACKET_TYPE_INITIAL
and header.packet_type == QuicPacketType.INITIAL
and len(data) < SMALLEST_MAX_DATAGRAM_SIZE
):
if self._quic_logger is not None:
Expand All @@ -800,7 +795,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
event="packet_dropped",
data={
"trigger": "initial_packet_datagram_too_small",
"raw": {"length": buf.capacity - start_off},
"raw": {"length": header.packet_length},
},
)
return
Expand All @@ -812,56 +807,60 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
destination_cid_seq = connection_id.sequence_number
break
if (
self._is_client or header.packet_type == PACKET_TYPE_HANDSHAKE
self._is_client or header.packet_type == QuicPacketType.HANDSHAKE
) and destination_cid_seq is None:
if self._quic_logger is not None:
self._quic_logger.log_event(
category="transport",
event="packet_dropped",
data={"trigger": "unknown_connection_id"},
data={
"trigger": "unknown_connection_id",
"raw": {"length": header.packet_length},
},
)
return

# check protocol version
if (
self._is_client
and self._state == QuicConnectionState.FIRSTFLIGHT
and header.version == QuicProtocolVersion.NEGOTIATION
and header.packet_type == QuicPacketType.VERSION_NEGOTIATION
and not self._version_negotiation_count
):
# version negotiation
versions = []
while not buf.eof():
versions.append(buf.pull_uint32())
if self._quic_logger is not None:
self._quic_logger.log_event(
category="transport",
event="packet_received",
data={
"frames": [],
"header": {
"packet_type": "version_negotiation",
"packet_type": self._quic_logger.packet_type(
header.packet_type
),
"scid": dump_cid(header.source_cid),
"dcid": dump_cid(header.destination_cid),
},
"raw": {"length": buf.tell() - start_off},
"raw": {"length": header.packet_length},
},
)
if self._version in versions:
if self._version in header.supported_versions:
self._logger.warning(
"Version negotiation packet contains %s" % self._version
)
return
common = [
x for x in self._configuration.supported_versions if x in versions
x
for x in self._configuration.supported_versions
if x in header.supported_versions
]
chosen_version = common[0] if common else None
if self._quic_logger is not None:
self._quic_logger.log_event(
category="transport",
event="version_information",
data={
"server_versions": versions,
"server_versions": header.supported_versions,
"client_versions": self._configuration.supported_versions,
"chosen_version": chosen_version,
},
Expand Down Expand Up @@ -890,12 +889,15 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
self._quic_logger.log_event(
category="transport",
event="packet_dropped",
data={"trigger": "unsupported_version"},
data={
"trigger": "unsupported_version",
"raw": {"length": header.packet_length},
},
)
return

# handle retry packet
if header.packet_type == PACKET_TYPE_RETRY:
if header.packet_type == QuicPacketType.RETRY:
if (
self._is_client
and not self._retry_count
Expand All @@ -920,7 +922,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
"scid": dump_cid(header.source_cid),
"dcid": dump_cid(header.destination_cid),
},
"raw": {"length": buf.tell() - start_off},
"raw": {"length": header.packet_length},
},
)

Expand All @@ -938,7 +940,10 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
self._quic_logger.log_event(
category="transport",
event="packet_dropped",
data={"trigger": "unexpected_packet"},
data={
"trigger": "unexpected_packet",
"raw": {"length": header.packet_length},
},
)
return

Expand All @@ -948,7 +953,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
# server initialization
if not self._is_client and self._state == QuicConnectionState.FIRSTFLIGHT:
assert (
header.packet_type == PACKET_TYPE_INITIAL
header.packet_type == QuicPacketType.INITIAL
), "first packet must be INITIAL"
crypto_frame_required = True
self._network_paths = [network_path]
Expand All @@ -965,7 +970,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non

# decrypt packet
encrypted_off = buf.tell() - start_off
end_off = buf.tell() + header.rest_length
end_off = start_off + header.packet_length
buf.seek(end_off)

try:
Expand All @@ -978,7 +983,10 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
self._quic_logger.log_event(
category="transport",
event="packet_dropped",
data={"trigger": "key_unavailable"},
data={
"trigger": "key_unavailable",
"raw": {"length": header.packet_length},
},
)

# If a client receives HANDSHAKE or 1-RTT packets before it has
Expand All @@ -997,15 +1005,18 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
self._quic_logger.log_event(
category="transport",
event="packet_dropped",
data={"trigger": "payload_decrypt_error"},
data={
"trigger": "payload_decrypt_error",
"raw": {"length": header.packet_length},
},
)
continue

# check reserved bits
if header.is_long_header:
reserved_mask = 0x0C
else:
if header.packet_type == QuicPacketType.ONE_RTT:
reserved_mask = 0x18
else:
reserved_mask = 0x0C
if plain_header[0] & reserved_mask:
self.close(
error_code=QuicErrorCode.PROTOCOL_VIOLATION,
Expand All @@ -1031,7 +1042,7 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
"dcid": dump_cid(header.destination_cid),
"scid": dump_cid(header.source_cid),
},
"raw": {"length": end_off - start_off},
"raw": {"length": header.packet_length},
},
)

Expand All @@ -1053,7 +1064,10 @@ def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> Non
self._set_state(QuicConnectionState.CONNECTED)

# update spin bit
if not header.is_long_header and packet_number > self._spin_highest_pn:
if (
header.packet_type == QuicPacketType.ONE_RTT
and packet_number > self._spin_highest_pn
):
spin_bit = get_spin_bit(plain_header[0])
if self._is_client:
self._spin_bit = not spin_bit
Expand Down Expand Up @@ -2802,10 +2816,10 @@ def _write_application(
if self._cryptos[tls.Epoch.ONE_RTT].send.is_valid():
crypto = self._cryptos[tls.Epoch.ONE_RTT]
crypto_stream = self._crypto_streams[tls.Epoch.ONE_RTT]
packet_type = PACKET_TYPE_ONE_RTT
packet_type = QuicPacketType.ONE_RTT
elif self._cryptos[tls.Epoch.ZERO_RTT].send.is_valid():
crypto = self._cryptos[tls.Epoch.ZERO_RTT]
packet_type = PACKET_TYPE_ZERO_RTT
packet_type = QuicPacketType.ZERO_RTT
else:
return
space = self._spaces[tls.Epoch.ONE_RTT]
Expand Down Expand Up @@ -2977,9 +2991,9 @@ def _write_handshake(

while True:
if epoch == tls.Epoch.INITIAL:
packet_type = PACKET_TYPE_INITIAL
packet_type = QuicPacketType.INITIAL
else:
packet_type = PACKET_TYPE_HANDSHAKE
packet_type = QuicPacketType.HANDSHAKE
builder.start_packet(packet_type, crypto)

# ACK
Expand Down
22 changes: 9 additions & 13 deletions src/aioquic/quic/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,20 @@

from ..h3.events import Headers
from .packet import (
PACKET_TYPE_HANDSHAKE,
PACKET_TYPE_INITIAL,
PACKET_TYPE_MASK,
PACKET_TYPE_ONE_RTT,
PACKET_TYPE_RETRY,
PACKET_TYPE_ZERO_RTT,
QuicFrameType,
QuicPacketType,
QuicStreamFrame,
QuicTransportParameters,
)
from .rangeset import RangeSet

PACKET_TYPE_NAMES = {
PACKET_TYPE_INITIAL: "initial",
PACKET_TYPE_HANDSHAKE: "handshake",
PACKET_TYPE_ZERO_RTT: "0RTT",
PACKET_TYPE_ONE_RTT: "1RTT",
PACKET_TYPE_RETRY: "retry",
QuicPacketType.INITIAL: "initial",
QuicPacketType.HANDSHAKE: "handshake",
QuicPacketType.ZERO_RTT: "0RTT",
QuicPacketType.ONE_RTT: "1RTT",
QuicPacketType.RETRY: "retry",
QuicPacketType.VERSION_NEGOTIATION: "version_negotiation",
}
QLOG_VERSION = "0.3"

Expand Down Expand Up @@ -212,8 +208,8 @@ def encode_transport_parameters(
data[param_name] = param_value
return data

def packet_type(self, packet_type: int) -> str:
return PACKET_TYPE_NAMES.get(packet_type & PACKET_TYPE_MASK, "1RTT")
def packet_type(self, packet_type: QuicPacketType) -> str:
return PACKET_TYPE_NAMES[packet_type]

# HTTP/3

Expand Down
Loading

0 comments on commit bd3497c

Please sign in to comment.