Skip to content

Commit

Permalink
[web transport] add initial code for WebTransport support (aiortc#204)
Browse files Browse the repository at this point in the history
Add an initial API to handle WebTransport over HTTP/3 with support for datagrams, bidirectional streams and unidirectional streams.
  • Loading branch information
jlaine authored Jul 13, 2021
1 parent 62e9ba2 commit 59134f3
Show file tree
Hide file tree
Showing 5 changed files with 593 additions and 24 deletions.
6 changes: 6 additions & 0 deletions docs/h3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ Events
.. autoclass:: H3Event
:members:

.. autoclass:: DatagramReceived
:members:

.. autoclass:: DataReceived
:members:

Expand All @@ -31,6 +34,9 @@ Events
.. autoclass:: PushPromiseReceived
:members:

.. autoclass:: WebTransportStreamDataReceived
:members:


Exceptions
----------
Expand Down
197 changes: 173 additions & 24 deletions src/aioquic/h3/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@

from aioquic.buffer import UINT_VAR_MAX_SIZE, Buffer, BufferReadError, encode_uint_var
from aioquic.h3.events import (
DatagramReceived,
DataReceived,
H3Event,
Headers,
HeadersReceived,
PushPromiseReceived,
WebTransportStreamDataReceived,
)
from aioquic.h3.exceptions import NoAvailablePushIDError
from aioquic.quic.connection import QuicConnection, stream_is_unidirectional
from aioquic.quic.events import QuicEvent, StreamDataReceived
from aioquic.quic.events import DatagramFrameReceived, QuicEvent, StreamDataReceived
from aioquic.quic.logger import QuicLoggerTrace

logger = logging.getLogger("http3")
Expand Down Expand Up @@ -58,6 +60,7 @@ class FrameType(IntEnum):
GOAWAY = 0x7
MAX_PUSH_ID = 0xD
DUPLICATE_PUSH = 0xE
WEBTRANSPORT_STREAM = 0x41


class HeadersState(Enum):
Expand All @@ -71,6 +74,8 @@ class Setting(IntEnum):
SETTINGS_MAX_HEADER_LIST_SIZE = 0x6
QPACK_BLOCKED_STREAMS = 0x7
SETTINGS_NUM_PLACEHOLDERS = 0x9
H3_DATAGRAM = 0x276
SETTINGS_ENABLE_WEBTRANSPORT = 0x2B603742

#  Dummy setting to check it is correctly ignored by the peer.
# https://tools.ietf.org/html/draft-ietf-quic-http-34#section-7.2.4.1
Expand All @@ -82,6 +87,7 @@ class StreamType(IntEnum):
PUSH = 1
QPACK_ENCODER = 2
QPACK_DECODER = 3
WEBTRANSPORT = 0x54


class ProtocolError(Exception):
Expand Down Expand Up @@ -310,6 +316,7 @@ def __init__(self, stream_id: int) -> None:
self.headers_recv_state: HeadersState = HeadersState.INITIAL
self.headers_send_state: HeadersState = HeadersState.INITIAL
self.push_id: Optional[int] = None
self.session_id: Optional[int] = None
self.stream_id = stream_id
self.stream_type: Optional[int] = None

Expand All @@ -321,9 +328,11 @@ class H3Connection:
:param quic: A :class:`~aioquic.connection.QuicConnection` instance.
"""

def __init__(self, quic: QuicConnection):
def __init__(self, quic: QuicConnection, enable_webtransport: bool = False) -> None:
# settings
self._max_table_capacity = 4096
self._blocked_streams = 16
self._enable_webtransport = enable_webtransport

self._is_client = quic.configuration.is_client
self._is_done = False
Expand Down Expand Up @@ -353,31 +362,66 @@ def __init__(self, quic: QuicConnection):

self._init_connection()

def create_webtransport_stream(
self, session_id: int, is_unidirectional: bool = False
) -> int:
"""
Create a WebTransport stream and return the stream ID.
:param session_id: The WebTransport session identifier.
:param is_unidirectional: Whether to create a unidirectional stream.
"""
if is_unidirectional:
stream_id = self._create_uni_stream(StreamType.WEBTRANSPORT)
self._quic.send_stream_data(stream_id, encode_uint_var(session_id))
else:
stream_id = self._quic.get_next_available_stream_id()
self._quic.send_stream_data(
stream_id,
encode_uint_var(FrameType.WEBTRANSPORT_STREAM)
+ encode_uint_var(session_id),
)
return stream_id

def handle_event(self, event: QuicEvent) -> List[H3Event]:
"""
Handle a QUIC event and return a list of HTTP events.
:param event: The QUIC event to handle.
"""
if isinstance(event, StreamDataReceived) and not self._is_done:
stream_id = event.stream_id
stream = self._get_or_create_stream(stream_id)

if not self._is_done:
try:
if stream_id % 4 == 0:
return self._receive_request_or_push_data(
stream, event.data, event.end_stream
)
elif stream_is_unidirectional(stream_id):
return self._receive_stream_data_uni(
stream, event.data, event.end_stream
)
if isinstance(event, StreamDataReceived):
stream_id = event.stream_id
stream = self._get_or_create_stream(stream_id)
if stream_id % 4 == 0:
return self._receive_request_or_push_data(
stream, event.data, event.end_stream
)
elif stream_is_unidirectional(stream_id):
return self._receive_stream_data_uni(
stream, event.data, event.end_stream
)
elif isinstance(event, DatagramFrameReceived):
return self._receive_datagram(event.data)
except ProtocolError as exc:
self._is_done = True
self._quic.close(
error_code=exc.error_code, reason_phrase=exc.reason_phrase
)

return []

def send_datagram(self, flow_id: int, data: bytes) -> None:
"""
Send a datagram for the specified flow.
:param flow_id: The flow ID.
:param data: The HTTP/3 datagram payload.
"""
self._quic.send_datagram_frame(encode_uint_var(flow_id) + data)

def send_push_promise(self, stream_id: int, headers: Headers) -> int:
"""
Send a push promise related to the specified stream.
Expand Down Expand Up @@ -517,6 +561,20 @@ def _get_or_create_stream(self, stream_id: int) -> H3Stream:
self._stream[stream_id] = H3Stream(stream_id)
return self._stream[stream_id]

def _get_local_settings(self) -> Dict[int, int]:
"""
Return the local HTTP/3 settings.
"""
settings = {
Setting.QPACK_MAX_TABLE_CAPACITY: self._max_table_capacity,
Setting.QPACK_BLOCKED_STREAMS: self._blocked_streams,
Setting.DUMMY: 1,
}
if self._enable_webtransport:
settings[Setting.H3_DATAGRAM] = 1
settings[Setting.SETTINGS_ENABLE_WEBTRANSPORT] = 1
return settings

def _handle_control_frame(self, frame_type: int, frame_data: bytes) -> None:
"""
Handle a frame received on the peer's control stream.
Expand All @@ -528,6 +586,7 @@ def _handle_control_frame(self, frame_type: int, frame_data: bytes) -> None:
if self._settings_received:
raise FrameUnexpected("SETTINGS have already been received")
settings = parse_settings(frame_data)
self._validate_settings(settings)
encoder = self._encoder.apply_settings(
max_table_capacity=settings.get(Setting.QPACK_MAX_TABLE_CAPACITY, 0),
blocked_streams=settings.get(Setting.QPACK_BLOCKED_STREAMS, 0),
Expand Down Expand Up @@ -616,7 +675,7 @@ def _handle_request_or_push_frame(
stream_ended=stream_ended,
)
)
elif stream.frame_type == FrameType.PUSH_PROMISE and stream.push_id is None:
elif frame_type == FrameType.PUSH_PROMISE and stream.push_id is None:
if not self._is_client:
raise FrameUnexpected("Clients must not send PUSH_PROMISE")
frame_buf = Buffer(data=frame_data)
Expand Down Expand Up @@ -670,14 +729,7 @@ def _init_connection(self) -> None:
self._quic.send_stream_data(
self._local_control_stream_id,
encode_frame(
FrameType.SETTINGS,
encode_settings(
{
Setting.QPACK_MAX_TABLE_CAPACITY: self._max_table_capacity,
Setting.QPACK_BLOCKED_STREAMS: self._blocked_streams,
Setting.DUMMY: 1,
}
),
FrameType.SETTINGS, encode_settings(self._get_local_settings())
),
)
if self._is_client and self._max_push_id is not None:
Expand All @@ -694,6 +746,17 @@ def _init_connection(self) -> None:
StreamType.QPACK_DECODER
)

def _receive_datagram(self, data: bytes) -> List[H3Event]:
"""
Handle a datagram.
"""
buf = Buffer(data=data)
try:
flow_id = buf.pull_uint_var()
except BufferReadError:
raise ProtocolError("Could not parse flow ID")
return [DatagramReceived(data=data[buf.tell() :], flow_id=flow_id)]

def _receive_request_or_push_data(
self, stream: H3Stream, data: bytes, stream_ended: bool
) -> List[H3Event]:
Expand All @@ -708,6 +771,22 @@ def _receive_request_or_push_data(
if stream.blocked:
return http_events

# shortcut for WEBTRANSPORT_STREAM frame fragments
if (
stream.frame_type == FrameType.WEBTRANSPORT_STREAM
and stream.session_id is not None
):
http_events.append(
WebTransportStreamDataReceived(
data=stream.buffer,
session_id=stream.session_id,
stream_id=stream.stream_id,
stream_ended=stream_ended,
)
)
stream.buffer = b""
return http_events

# shortcut for DATA frame fragments
if (
stream.frame_type == FrameType.DATA
Expand Down Expand Up @@ -751,6 +830,25 @@ def _receive_request_or_push_data(
break
consumed = buf.tell()

# WEBTRANSPORT_STREAM frames last until the end of the stream
if stream.frame_type == FrameType.WEBTRANSPORT_STREAM:
stream.session_id = stream.frame_size
stream.frame_size = None

frame_data = stream.buffer[consumed:]
stream.buffer = b""

if frame_data or stream_ended:
http_events.append(
WebTransportStreamDataReceived(
data=frame_data,
session_id=stream.session_id,
stream_id=stream.stream_id,
stream_ended=stream_ended,
)
)
return http_events

# log frame
if (
self._quic_logger is not None
Expand All @@ -771,17 +869,19 @@ def _receive_request_or_push_data(

# read available data
frame_data = buf.pull_bytes(chunk_size)
frame_type = stream.frame_type
consumed = buf.tell()

# detect end of frame
stream.frame_size -= chunk_size
if not stream.frame_size:
stream.frame_size = None
stream.frame_type = None

try:
http_events.extend(
self._handle_request_or_push_frame(
frame_type=stream.frame_type,
frame_type=frame_type,
frame_data=frame_data,
stream=stream,
stream_ended=stream.ended and buf.eof(),
Expand Down Expand Up @@ -811,7 +911,9 @@ def _receive_stream_data_uni(
unblocked_streams: Set[int] = set()

while (
stream.stream_type in (StreamType.PUSH, StreamType.CONTROL) or not buf.eof()
stream.stream_type
in (StreamType.PUSH, StreamType.CONTROL, StreamType.WEBTRANSPORT)
or not buf.eof()
):
# fetch stream type for unidirectional streams
if stream.stream_type is None:
Expand Down Expand Up @@ -866,6 +968,28 @@ def _receive_stream_data_uni(
stream.buffer = stream.buffer[consumed:]

return self._receive_request_or_push_data(stream, b"", stream_ended)
elif stream.stream_type == StreamType.WEBTRANSPORT:
# fetch session id
if stream.session_id is None:
try:
stream.session_id = buf.pull_uint_var()
except BufferReadError:
break
consumed = buf.tell()

frame_data = stream.buffer[consumed:]
stream.buffer = b""

if frame_data or stream_ended:
http_events.append(
WebTransportStreamDataReceived(
data=frame_data,
session_id=stream.session_id,
stream_ended=stream.ended,
stream_id=stream.stream_id,
)
)
return http_events
elif stream.stream_type == StreamType.QPACK_DECODER:
# feed unframed data to decoder
data = buf.pull_bytes(buf.capacity - buf.tell())
Expand Down Expand Up @@ -915,3 +1039,28 @@ def _receive_stream_data_uni(
)

return http_events

def _validate_settings(self, settings: Dict[int, int]) -> None:
if Setting.H3_DATAGRAM in settings:
if settings[Setting.H3_DATAGRAM] not in (0, 1):
raise SettingsError("H3_DATAGRAM setting must be 0 or 1")

if (
settings[Setting.H3_DATAGRAM] == 1
and self._quic._remote_max_datagram_frame_size is None
):
raise SettingsError(
"H3_DATAGRAM requires max_datagram_frame_size transport parameter"
)

if Setting.SETTINGS_ENABLE_WEBTRANSPORT in settings:
if settings[Setting.SETTINGS_ENABLE_WEBTRANSPORT] not in (0, 1):
raise SettingsError(
"SETTINGS_ENABLE_WEBTRANSPORT setting must be 0 or 1"
)

if (
settings[Setting.SETTINGS_ENABLE_WEBTRANSPORT] == 1
and settings.get(Setting.H3_DATAGRAM) != 1
):
raise SettingsError("SETTINGS_ENABLE_WEBTRANSPORT requires H3_DATAGRAM")
Loading

0 comments on commit 59134f3

Please sign in to comment.