Skip to content

Commit

Permalink
StreamShutdown with Pending Receive (microsoft#1513)
Browse files Browse the repository at this point in the history
  • Loading branch information
nibanks authored Apr 24, 2021
1 parent f1dc5cc commit 5c22710
Show file tree
Hide file tree
Showing 9 changed files with 297 additions and 26 deletions.
20 changes: 18 additions & 2 deletions src/core/stream_recv.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,28 @@ QuicStreamRecvShutdown(
goto Exit;
}

Stream->SendCloseErrorCode = ErrorCode;
Stream->Flags.SentStopSending = TRUE;
//
// Disable all future receive events.
//
Stream->Flags.ReceiveEnabled = FALSE;
Stream->Flags.ReceiveDataPending = FALSE;
Stream->Flags.ReceiveCallPending = FALSE;

if (Stream->RecvMaxLength != UINT64_MAX) {
//
// The peer has already gracefully closed, but we just haven't drained
// the receives to that point. Ignore this abort from the app and jump
// right to the closed state.
//
Stream->Flags.RemoteCloseFin = TRUE;
Stream->Flags.RemoteCloseAcked = TRUE;
Silent = TRUE; // To indicate we try to shutdown complete.
goto Exit;
}

Stream->SendCloseErrorCode = ErrorCode;
Stream->Flags.SentStopSending = TRUE;

//
// Queue up a stop sending frame to be sent.
//
Expand Down
113 changes: 105 additions & 8 deletions src/inc/msquic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ struct MsQuicListener {
QUIC_STATUS InitStatus;

MsQuicListener(
const MsQuicRegistration& Registration,
_In_ const MsQuicRegistration& Registration,
_In_ QUIC_LISTENER_CALLBACK_HANDLER Handler,
_In_ void* Context = nullptr
) noexcept {
Expand Down Expand Up @@ -501,12 +501,54 @@ struct MsQuicListener {
operator HQUIC () const noexcept { return Handle; }
};

struct MsQuicAutoAcceptListener : public MsQuicListener {
const MsQuicConfiguration& Configuration;
QUIC_CONNECTION_CALLBACK_HANDLER ConnectionHandler;
void* ConnectionContext;

MsQuicAutoAcceptListener(
_In_ const MsQuicRegistration& Registration,
_In_ const MsQuicConfiguration& Config,
_In_ QUIC_CONNECTION_CALLBACK_HANDLER _ConnectionHandler,
_In_ void* _ConnectionContext = nullptr
) noexcept :
MsQuicListener(Registration, ListenerCallback, this),
Configuration(Config),
ConnectionHandler(_ConnectionHandler),
ConnectionContext(_ConnectionContext)
{ }

private:

static
_IRQL_requires_max_(PASSIVE_LEVEL)
_Function_class_(QUIC_LISTENER_CALLBACK)
QUIC_STATUS
QUIC_API
ListenerCallback(
_In_ HQUIC /* Listener */,
_In_opt_ void* Context,
_Inout_ QUIC_LISTENER_EVENT* Event
)
{
auto pThis = (MsQuicAutoAcceptListener*)Context;
CXPLAT_DBG_ASSERT(pThis);
switch (Event->Type) {
case QUIC_LISTENER_EVENT_NEW_CONNECTION:
MsQuic->SetCallbackHandler(Event->NEW_CONNECTION.Connection, (void*)pThis->ConnectionHandler, pThis->ConnectionContext);
return MsQuic->ConnectionSetConfiguration(Event->NEW_CONNECTION.Connection, pThis->Configuration);
default:
return QUIC_STATUS_INVALID_STATE;
}
}
};

struct MsQuicConnection {
HQUIC Handle { nullptr };
QUIC_STATUS InitStatus;

MsQuicConnection(
const MsQuicRegistration& Registration,
_In_ const MsQuicRegistration& Registration,
_In_ QUIC_CONNECTION_CALLBACK_HANDLER Handler,
_In_ void* Context = nullptr
) noexcept {
Expand All @@ -526,7 +568,7 @@ struct MsQuicConnection {
}

MsQuicConnection(
HQUIC ConnectionHandle,
_In_ HQUIC ConnectionHandle,
_In_ QUIC_CONNECTION_CALLBACK_HANDLER Handler,
_In_ void* Context = nullptr
) noexcept {
Expand Down Expand Up @@ -625,6 +667,57 @@ struct MsQuicConnection {
operator HQUIC () const noexcept { return Handle; }
};

struct MsQuicStream {
HQUIC Handle { nullptr };
QUIC_STATUS InitStatus;

MsQuicStream(
_In_ const MsQuicConnection& Connection,
_In_ QUIC_STREAM_OPEN_FLAGS Flags,
_In_ QUIC_STREAM_CALLBACK_HANDLER Handler,
_In_ void* Context = nullptr
) noexcept {
if (!Connection.IsValid()) {
InitStatus = Connection.GetInitStatus();
return;
}
if (QUIC_FAILED(
InitStatus =
MsQuic->StreamOpen(
Connection,
Flags,
Handler,
Context,
&Handle))) {
Handle = nullptr;
}
}

~MsQuicStream() noexcept {
if (Handle) {
MsQuic->StreamClose(Handle);
}
}

QUIC_STATUS
Send(
_In_reads_(BufferCount) _Pre_defensive_
const QUIC_BUFFER* const Buffers,
_In_ uint32_t BufferCount = 1,
_In_ QUIC_SEND_FLAGS Flags = QUIC_SEND_FLAG_NONE,
_In_opt_ void* ClientSendContext = nullptr
)
{
return MsQuic->StreamSend(Handle, Buffers, BufferCount, Flags, ClientSendContext);
}

QUIC_STATUS GetInitStatus() const noexcept { return InitStatus; }
bool IsValid() const { return QUIC_SUCCEEDED(InitStatus); }
MsQuicStream(MsQuicStream& other) = delete;
MsQuicStream operator=(MsQuicStream& Other) = delete;
operator HQUIC () const noexcept { return Handle; }
};

struct ConnectionScope {
HQUIC Handle;
ConnectionScope() noexcept : Handle(nullptr) { }
Expand Down Expand Up @@ -667,14 +760,18 @@ struct QuicBufferScope {
// Abstractions for platform specific types/interfaces
//

struct EventScope {
struct CxPlatEvent {
CXPLAT_EVENT Handle;
EventScope() noexcept { CxPlatEventInitialize(&Handle, FALSE, FALSE); }
EventScope(bool ManualReset) noexcept { CxPlatEventInitialize(&Handle, ManualReset, FALSE); }
EventScope(CXPLAT_EVENT event) noexcept : Handle(event) { }
~EventScope() noexcept { CxPlatEventUninitialize(Handle); }
CxPlatEvent() noexcept { CxPlatEventInitialize(&Handle, FALSE, FALSE); }
CxPlatEvent(bool ManualReset) noexcept { CxPlatEventInitialize(&Handle, ManualReset, FALSE); }
CxPlatEvent(CXPLAT_EVENT event) noexcept : Handle(event) { }
~CxPlatEvent() noexcept { CxPlatEventUninitialize(Handle); }
CXPLAT_EVENT* operator &() noexcept { return &Handle; }
operator CXPLAT_EVENT() const noexcept { return Handle; }
void Set() { CxPlatEventSet(Handle); }
void Reset() { CxPlatEventReset(Handle); }
void WaitForever() { CxPlatEventWaitForever(Handle); }
bool WaitTimeout(uint32_t TimeoutMs) { return CxPlatEventWaitWithTimeout(Handle, TimeoutMs); }
};

#ifdef CXPLAT_HASH_MIN_SIZE
Expand Down
2 changes: 1 addition & 1 deletion src/perf/bin/appmain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ QuicUserMain(
_In_ const QUIC_CREDENTIAL_CONFIG* SelfSignedCredConfig,
_In_opt_z_ const char* FileName
) {
EventScope StopEvent {true};
CxPlatEvent StopEvent {true};

QUIC_STATUS Status;

Expand Down
2 changes: 1 addition & 1 deletion src/perf/lib/RpsClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class RpsClient : public PerfBase {
CXPLAT_EVENT* CompletionEvent {nullptr};
QUIC_ADDR LocalAddresses[RPS_MAX_CLIENT_PORT_COUNT];
uint32_t ActiveConnections {0};
EventScope AllConnected {true};
CxPlatEvent AllConnected {true};
uint64_t StartedRequests {0};
uint64_t SendCompletedRequests {0};
uint64_t CompletedRequests {0};
Expand Down
13 changes: 12 additions & 1 deletion src/test/MsQuicTests.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ typedef union QUIC_ABORTIVE_TRANSFER_FLAGS {
uint32_t WaitForStream : 1;
uint32_t ShutdownDirection : 2;
uint32_t UnidirectionalStream : 1;
uint32_t PauseReceive : 1;
uint32_t PendReceive : 1;
};
uint32_t IntValue;
} QUIC_ABORTIVE_TRANSFER_FLAGS;
Expand Down Expand Up @@ -332,6 +334,11 @@ QuicTestAckSendDelay(
_In_ int Family
);

void
QuicTestAbortReceive(
_In_ bool IsPaused
);

//
// QuicDrill tests
//
Expand Down Expand Up @@ -769,4 +776,8 @@ typedef struct {
#define IOCTL_QUIC_RUN_EXPIRED_CLIENT_CERT \
QUIC_CTL_CODE(62, METHOD_BUFFERED, FILE_WRITE_DATA)

#define QUIC_MAX_IOCTL_FUNC_CODE 62
#define IOCTL_QUIC_RUN_ABORT_RECEIVE \
QUIC_CTL_CODE(63, METHOD_BUFFERED, FILE_WRITE_DATA)
// BOOLEAN

#define QUIC_MAX_IOCTL_FUNC_CODE 63
20 changes: 20 additions & 0 deletions src/test/bin/quic_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,26 @@ TEST_P(WithFamilyArgs, AckSendDelay) {
}
}

TEST(Misc, AbortPausedReceive) {
TestLogger Logger("AbortPausedReceive");
if (TestingKernelMode) {
BOOLEAN IsPaused = TRUE;
ASSERT_TRUE(DriverClient.Run(IOCTL_QUIC_RUN_ABORT_RECEIVE, IsPaused));
} else {
QuicTestAbortReceive(true);
}
}

TEST(Misc, AbortPendingReceive) {
TestLogger Logger("AbortPendingReceive");
if (TestingKernelMode) {
BOOLEAN IsPaused = FALSE;
ASSERT_TRUE(DriverClient.Run(IOCTL_QUIC_RUN_ABORT_RECEIVE, IsPaused));
} else {
QuicTestAbortReceive(false);
}
}

TEST(Drill, VarIntEncoder) {
TestLogger Logger("QuicDrillTestVarIntEncoder");
if (TestingKernelMode) {
Expand Down
14 changes: 12 additions & 2 deletions src/test/bin/quic_gtest.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,15 +403,25 @@ struct AbortiveArgs {
for (uint32_t WaitForStream : { 1 })
for (uint32_t ShutdownDirection : { 0, 1, 2 })
for (uint32_t UnidirectionStream : { 0, 1 })
list.push_back({ Family, {{ DelayStreamCreation, SendDataOnStream, ClientShutdown, DelayClientShutdown, WaitForStream, ShutdownDirection, UnidirectionStream }} });
for (uint32_t PauseReceive : { 0, 1 })
for (uint32_t PendReceive : { 0, 1 })
list.push_back({ Family, {{ DelayStreamCreation, SendDataOnStream, ClientShutdown, DelayClientShutdown, WaitForStream, ShutdownDirection, UnidirectionStream, PauseReceive, PendReceive }} });
return list;
}
};

std::ostream& operator << (std::ostream& o, const AbortiveArgs& args) {
return o <<
(args.Family == 4 ? "v4" : "v6") << "/" <<
args.Flags.IntValue;
args.Flags.DelayStreamCreation << "/" <<
args.Flags.SendDataOnStream << "/" <<
args.Flags.ClientShutdown << "/" <<
args.Flags.DelayClientShutdown << "/" <<
args.Flags.WaitForStream << "/" <<
args.Flags.ShutdownDirection << "/" <<
args.Flags.UnidirectionalStream << "/" <<
args.Flags.PauseReceive << "/" <<
args.Flags.PendReceive;
}

class WithAbortiveArgs : public testing::Test,
Expand Down
9 changes: 8 additions & 1 deletion src/test/bin/winkernel/control.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,8 @@ size_t QUIC_IOCTL_BUFFER_SIZES[] =
sizeof(QUIC_RUN_CRED_VALIDATION),
sizeof(QUIC_RUN_CRED_VALIDATION),
sizeof(QUIC_RUN_CRED_VALIDATION),
sizeof(QUIC_RUN_CRED_VALIDATION)
sizeof(QUIC_RUN_CRED_VALIDATION),
sizeof(BOOLEAN)
};

CXPLAT_STATIC_ASSERT(
Expand All @@ -457,6 +458,7 @@ typedef union {
QUIC_RUN_VERSION_NEGOTIATION_EXT VersionNegotiationExtParams;
QUIC_RUN_CONNECT_CLIENT_CERT ConnectClientCertParams;
QUIC_RUN_CRED_VALIDATION CredValidationParams;
BOOLEAN IsPaused;

} QUIC_IOCTL_PARAMS;

Expand Down Expand Up @@ -1013,6 +1015,11 @@ QuicTestCtlEvtIoDeviceControl(
&Params->CredValidationParams.CredConfig));
break;

case IOCTL_QUIC_RUN_ABORT_RECEIVE:
CXPLAT_FRE_ASSERT(Params != nullptr);
QuicTestCtlRun(QuicTestAbortReceive(Params->IsPaused));
break;

default:
Status = STATUS_NOT_IMPLEMENTED;
break;
Expand Down
Loading

0 comments on commit 5c22710

Please sign in to comment.