Skip to content

Commit

Permalink
pc: Simplify StreamId class
Browse files Browse the repository at this point in the history
Before this CL, the StreamId class represented either a valid SCTP
stream ID, or "nothing", which means that it was a wrapped
absl::optional. Since created data channels don't have a SCTP stream ID
until it's known whether this peer will use odd or even numbers, the
"nothing" value was used for that state.

This unfortunately made it a bit hard to work with objects of this type,
as one always had to check if it contained a value. And even if a caller
would check this, and then pass the StreamId to a different function,
that function would have to do the check itself (often as a RTC_DCHECK)
since the passed StreamId always could have that state.

This CL simply extracts the "absl::optional" part of it, forcing holders
to wrap it in an optional type - when it can be "nothing". But allowing
the other code to just pass StreamId that can't be "nothing". That
simplifies the code a bit, potentially removing some bugs.

Bug: chromium:41221056
Change-Id: I93104cdd5d2f5fc1dbeb9d9dfc4cf361f11a9d68
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/342440
Reviewed-by: Florent Castelli <[email protected]>
Reviewed-by: Tomas Gunnarsson <[email protected]>
Commit-Queue: Victor Boivie <[email protected]>
Cr-Commit-Position: refs/heads/main@{#41880}
  • Loading branch information
Victor Boivie authored and WebRTC LUCI CQ committed Mar 12, 2024
1 parent b4913a5 commit cd3d29b
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 123 deletions.
63 changes: 36 additions & 27 deletions pc/data_channel_controller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/types/optional.h"
#include "api/peer_connection_interface.h"
#include "api/rtc_error.h"
#include "pc/peer_connection_internal.h"
Expand Down Expand Up @@ -53,7 +54,6 @@ RTCError DataChannelController::SendData(

void DataChannelController::AddSctpDataStream(StreamId sid) {
RTC_DCHECK_RUN_ON(network_thread());
RTC_DCHECK(sid.HasValue());
if (data_channel_transport_) {
data_channel_transport_->OpenChannel(sid.stream_id_int());
}
Expand Down Expand Up @@ -99,7 +99,7 @@ void DataChannelController::OnDataReceived(
return;

auto it = absl::c_find_if(sctp_data_channels_n_, [&](const auto& c) {
return c->sid_n().stream_id_int() == channel_id;
return c->sid_n().has_value() && c->sid_n()->stream_id_int() == channel_id;
});

if (it != sctp_data_channels_n_.end())
Expand All @@ -109,7 +109,7 @@ void DataChannelController::OnDataReceived(
void DataChannelController::OnChannelClosing(int channel_id) {
RTC_DCHECK_RUN_ON(network_thread());
auto it = absl::c_find_if(sctp_data_channels_n_, [&](const auto& c) {
return c->sid_n().stream_id_int() == channel_id;
return c->sid_n().has_value() && c->sid_n()->stream_id_int() == channel_id;
});

if (it != sctp_data_channels_n_.end())
Expand All @@ -134,7 +134,7 @@ void DataChannelController::OnReadyToSend() {
RTC_DCHECK_RUN_ON(network_thread());
auto copy = sctp_data_channels_n_;
for (const auto& channel : copy) {
if (channel->sid_n().HasValue()) {
if (channel->sid_n().has_value()) {
channel->OnTransportReady();
} else {
// This happens for role==SSL_SERVER channels when we get notified by
Expand All @@ -157,7 +157,9 @@ void DataChannelController::OnTransportClosed(RTCError error) {
temp_sctp_dcs.swap(sctp_data_channels_n_);
for (const auto& channel : temp_sctp_dcs) {
channel->OnTransportChannelClosed(error);
sid_allocator_.ReleaseSid(channel->sid_n());
if (channel->sid_n().has_value()) {
sid_allocator_.ReleaseSid(*channel->sid_n());
}
}
}

Expand Down Expand Up @@ -257,13 +259,12 @@ void DataChannelController::OnDataChannelOpenMessage(

// RTC_RUN_ON(network_thread())
RTCError DataChannelController::ReserveOrAllocateSid(
StreamId& sid,
absl::optional<StreamId>& sid,
absl::optional<rtc::SSLRole> fallback_ssl_role) {
if (sid.HasValue()) {
return sid_allocator_.ReserveSid(sid)
if (sid.has_value()) {
return sid_allocator_.ReserveSid(*sid)
? RTCError::OK()
: RTCError(RTCErrorType::INVALID_RANGE,
"StreamId out of range or reserved.");
: RTCError(RTCErrorType::INVALID_RANGE, "StreamId reserved.");
}

// Attempt to allocate an ID based on the negotiated role.
Expand All @@ -272,26 +273,35 @@ RTCError DataChannelController::ReserveOrAllocateSid(
role = fallback_ssl_role;
if (role) {
sid = sid_allocator_.AllocateSid(*role);
if (!sid.HasValue())
if (!sid.has_value())
return RTCError(RTCErrorType::RESOURCE_EXHAUSTED);
}
// When we get here, we may still not have an ID, but that's a supported case
// whereby an id will be assigned later.
RTC_DCHECK(sid.HasValue() || !role);
RTC_DCHECK(sid.has_value() || !role);
return RTCError::OK();
}

// RTC_RUN_ON(network_thread())
RTCErrorOr<rtc::scoped_refptr<SctpDataChannel>>
DataChannelController::CreateDataChannel(const std::string& label,
InternalDataChannelInit& config) {
StreamId sid(config.id);
absl::optional<StreamId> sid = absl::nullopt;
if (config.id != -1) {
if (config.id < 0 || config.id > cricket::kMaxSctpSid) {
return RTCError(RTCErrorType::INVALID_RANGE, "StreamId out of range.");
}
sid = StreamId(config.id);
}

RTCError err = ReserveOrAllocateSid(sid, config.fallback_ssl_role);
if (!err.ok())
return err;

// In case `sid` has changed. Update `config` accordingly.
config.id = sid.stream_id_int();
if (sid.has_value()) {
config.id = sid->stream_id_int();
}

rtc::scoped_refptr<SctpDataChannel> channel = SctpDataChannel::Create(
weak_factory_.GetWeakPtr(), label, data_channel_transport_ != nullptr,
Expand All @@ -300,8 +310,8 @@ DataChannelController::CreateDataChannel(const std::string& label,
sctp_data_channels_n_.push_back(channel);

// If we have an id already, notify the transport.
if (sid.HasValue())
AddSctpDataStream(sid);
if (sid.has_value())
AddSctpDataStream(*sid);

return channel;
}
Expand All @@ -319,7 +329,6 @@ DataChannelController::InternalCreateDataChannelWithProxy(

bool ready_to_send = false;
InternalDataChannelInit new_config = config;
StreamId sid(new_config.id);
auto ret = network_thread()->BlockingCall(
[&]() -> RTCErrorOr<rtc::scoped_refptr<SctpDataChannel>> {
RTC_DCHECK_RUN_ON(network_thread());
Expand Down Expand Up @@ -361,16 +370,16 @@ void DataChannelController::AllocateSctpSids(rtc::SSLRole role) {
std::vector<rtc::scoped_refptr<SctpDataChannel>> channels_to_close;
for (auto it = sctp_data_channels_n_.begin();
it != sctp_data_channels_n_.end();) {
if (!(*it)->sid_n().HasValue()) {
StreamId sid = sid_allocator_.AllocateSid(role);
if (sid.HasValue()) {
(*it)->SetSctpSid_n(sid);
AddSctpDataStream(sid);
if (!(*it)->sid_n().has_value()) {
absl::optional<StreamId> sid = sid_allocator_.AllocateSid(role);
if (sid.has_value()) {
(*it)->SetSctpSid_n(*sid);
AddSctpDataStream(*sid);
if (ready_to_send) {
RTC_LOG(LS_INFO) << "AllocateSctpSids: Id assigned, ready to send.";
(*it)->OnTransportReady();
}
channels_to_update.push_back(std::make_pair((*it).get(), sid));
channels_to_update.push_back(std::make_pair((*it).get(), *sid));
} else {
channels_to_close.push_back(std::move(*it));
it = sctp_data_channels_n_.erase(it);
Expand All @@ -391,8 +400,8 @@ void DataChannelController::OnSctpDataChannelClosed(SctpDataChannel* channel) {
RTC_DCHECK_RUN_ON(network_thread());
// After the closing procedure is done, it's safe to use this ID for
// another data channel.
if (channel->sid_n().HasValue()) {
sid_allocator_.ReleaseSid(channel->sid_n());
if (channel->sid_n().has_value()) {
sid_allocator_.ReleaseSid(*channel->sid_n());
}
auto it = absl::c_find_if(sctp_data_channels_n_,
[&](const auto& c) { return c.get() == channel; });
Expand Down Expand Up @@ -423,8 +432,8 @@ void DataChannelController::NotifyDataChannelsOfTransportCreated() {
RTC_DCHECK(data_channel_transport_);

for (const auto& channel : sctp_data_channels_n_) {
if (channel->sid_n().HasValue())
AddSctpDataStream(channel->sid_n());
if (channel->sid_n().has_value())
AddSctpDataStream(*channel->sid_n());
channel->OnTransportChannelCreated();
}
}
Expand Down
2 changes: 1 addition & 1 deletion pc/data_channel_controller.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class DataChannelController : public SctpDataChannelControllerInterface,
// will still be unassigned upon return, but will be assigned later.
// If the pool has been exhausted or a sid has already been reserved, an
// error will be returned.
RTCError ReserveOrAllocateSid(StreamId& sid,
RTCError ReserveOrAllocateSid(absl::optional<StreamId>& sid,
absl::optional<rtc::SSLRole> fallback_ssl_role)
RTC_RUN_ON(network_thread());

Expand Down
34 changes: 18 additions & 16 deletions pc/data_channel_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class SctpDataChannelTest : public ::testing::Test {
StreamId sid(0);
network_thread_.BlockingCall([&]() {
RTC_DCHECK_RUN_ON(&network_thread_);
if (!inner_channel_->sid_n().HasValue()) {
if (!inner_channel_->sid_n().has_value()) {
inner_channel_->SetSctpSid_n(sid);
controller_->AddSctpDataStream(sid);
}
Expand All @@ -115,7 +115,6 @@ class SctpDataChannelTest : public ::testing::Test {
// to run on the network thread.
void SetChannelSid(const rtc::scoped_refptr<SctpDataChannel>& channel,
StreamId sid) {
RTC_DCHECK(sid.HasValue());
network_thread_.BlockingCall([&]() {
channel->SetSctpSid_n(sid);
controller_->AddSctpDataStream(sid);
Expand Down Expand Up @@ -172,7 +171,7 @@ TEST_F(SctpDataChannelTest, VerifyConfigurationGetters) {
// Check the non-const part of the configuration.
EXPECT_EQ(channel_->id(), init_.id);
network_thread_.BlockingCall(
[&]() { EXPECT_EQ(inner_channel_->sid_n(), StreamId()); });
[&]() { EXPECT_EQ(inner_channel_->sid_n(), absl::nullopt); });

SetChannelReady();
EXPECT_EQ(channel_->id(), 0);
Expand All @@ -188,12 +187,14 @@ TEST_F(SctpDataChannelTest, ConnectedToTransportOnCreated) {
EXPECT_TRUE(controller_->IsConnected(dc.get()));

// The sid is not set yet, so it should not have added the streams.
StreamId sid = network_thread_.BlockingCall([&]() { return dc->sid_n(); });
EXPECT_FALSE(controller_->IsStreamAdded(sid));
absl::optional<StreamId> sid =
network_thread_.BlockingCall([&]() { return dc->sid_n(); });
EXPECT_FALSE(sid.has_value());

SetChannelSid(dc, StreamId(0));
sid = network_thread_.BlockingCall([&]() { return dc->sid_n(); });
EXPECT_TRUE(controller_->IsStreamAdded(sid));
ASSERT_TRUE(sid.has_value());
EXPECT_TRUE(controller_->IsStreamAdded(*sid));
}

// Tests the state of the data channel.
Expand Down Expand Up @@ -1035,14 +1036,14 @@ TEST_F(SctpSidAllocatorTest, SctpIdAllocationNoReuse) {
StreamId old_id(1);
EXPECT_TRUE(allocator_.ReserveSid(old_id));

StreamId new_id = allocator_.AllocateSid(rtc::SSL_SERVER);
EXPECT_TRUE(new_id.HasValue());
absl::optional<StreamId> new_id = allocator_.AllocateSid(rtc::SSL_SERVER);
EXPECT_TRUE(new_id.has_value());
EXPECT_NE(old_id, new_id);

old_id = StreamId(0);
EXPECT_TRUE(allocator_.ReserveSid(old_id));
new_id = allocator_.AllocateSid(rtc::SSL_CLIENT);
EXPECT_TRUE(new_id.HasValue());
EXPECT_TRUE(new_id.has_value());
EXPECT_NE(old_id, new_id);
}

Expand All @@ -1053,17 +1054,18 @@ TEST_F(SctpSidAllocatorTest, SctpIdReusedForRemovedDataChannel) {
EXPECT_TRUE(allocator_.ReserveSid(odd_id));
EXPECT_TRUE(allocator_.ReserveSid(even_id));

StreamId allocated_id = allocator_.AllocateSid(rtc::SSL_SERVER);
EXPECT_EQ(odd_id.stream_id_int() + 2, allocated_id.stream_id_int());
absl::optional<StreamId> allocated_id =
allocator_.AllocateSid(rtc::SSL_SERVER);
EXPECT_EQ(odd_id.stream_id_int() + 2, allocated_id->stream_id_int());

allocated_id = allocator_.AllocateSid(rtc::SSL_CLIENT);
EXPECT_EQ(even_id.stream_id_int() + 2, allocated_id.stream_id_int());
EXPECT_EQ(even_id.stream_id_int() + 2, allocated_id->stream_id_int());

allocated_id = allocator_.AllocateSid(rtc::SSL_SERVER);
EXPECT_EQ(odd_id.stream_id_int() + 4, allocated_id.stream_id_int());
EXPECT_EQ(odd_id.stream_id_int() + 4, allocated_id->stream_id_int());

allocated_id = allocator_.AllocateSid(rtc::SSL_CLIENT);
EXPECT_EQ(even_id.stream_id_int() + 4, allocated_id.stream_id_int());
EXPECT_EQ(even_id.stream_id_int() + 4, allocated_id->stream_id_int());

allocator_.ReleaseSid(odd_id);
allocator_.ReleaseSid(even_id);
Expand All @@ -1077,10 +1079,10 @@ TEST_F(SctpSidAllocatorTest, SctpIdReusedForRemovedDataChannel) {

// Verifies that used higher ids are not reused.
allocated_id = allocator_.AllocateSid(rtc::SSL_SERVER);
EXPECT_EQ(odd_id.stream_id_int() + 6, allocated_id.stream_id_int());
EXPECT_EQ(odd_id.stream_id_int() + 6, allocated_id->stream_id_int());

allocated_id = allocator_.AllocateSid(rtc::SSL_CLIENT);
EXPECT_EQ(even_id.stream_id_int() + 6, allocated_id.stream_id_int());
EXPECT_EQ(even_id.stream_id_int() + 6, allocated_id->stream_id_int());
}

// Code coverage tests for default implementations in data_channel_interface.*.
Expand Down
Loading

0 comments on commit cd3d29b

Please sign in to comment.