diff --git a/api/transport/data_channel_transport_interface.h b/api/transport/data_channel_transport_interface.h index cdae6fee19..27d7de6364 100644 --- a/api/transport/data_channel_transport_interface.h +++ b/api/transport/data_channel_transport_interface.h @@ -118,6 +118,8 @@ class DataChannelTransportInterface { // Note: the default implementation always returns false (as it assumes no one // has implemented the interface). This default implementation is temporary. virtual bool IsReadyToSend() const = 0; + + virtual size_t buffered_amount(int channel_id) const = 0; }; } // namespace webrtc diff --git a/media/sctp/dcsctp_transport.cc b/media/sctp/dcsctp_transport.cc index 013f48b550..bff9c29d17 100644 --- a/media/sctp/dcsctp_transport.cc +++ b/media/sctp/dcsctp_transport.cc @@ -375,6 +375,12 @@ absl::optional DcSctpTransport::max_inbound_streams() const { return socket_->options().announced_maximum_incoming_streams; } +size_t DcSctpTransport::buffered_amount(int sid) const { + if (!socket_) + return 0; + return socket_->buffered_amount(dcsctp::StreamID(sid)); +} + void DcSctpTransport::set_debug_name_for_testing(const char* debug_name) { debug_name_ = debug_name; } diff --git a/media/sctp/dcsctp_transport.h b/media/sctp/dcsctp_transport.h index c021422615..aa301d8496 100644 --- a/media/sctp/dcsctp_transport.h +++ b/media/sctp/dcsctp_transport.h @@ -66,6 +66,7 @@ class DcSctpTransport : public cricket::SctpTransportInternal, int max_message_size() const override; absl::optional max_outbound_streams() const override; absl::optional max_inbound_streams() const override; + size_t buffered_amount(int sid) const override; void set_debug_name_for_testing(const char* debug_name) override; private: diff --git a/media/sctp/sctp_transport_internal.h b/media/sctp/sctp_transport_internal.h index 8a7450f405..705f5bd3e6 100644 --- a/media/sctp/sctp_transport_internal.h +++ b/media/sctp/sctp_transport_internal.h @@ -140,6 +140,8 @@ class SctpTransportInternal { virtual absl::optional max_outbound_streams() const = 0; // Returns the current negotiated max # of inbound streams. virtual absl::optional max_inbound_streams() const = 0; + // Returns the amount of buffered data in the send queue for a stream. + virtual size_t buffered_amount(int sid) const = 0; // Helper for debugging. virtual void set_debug_name_for_testing(const char* debug_name) = 0; diff --git a/pc/data_channel_controller.cc b/pc/data_channel_controller.cc index 208160e1c1..fbe639f96b 100644 --- a/pc/data_channel_controller.cc +++ b/pc/data_channel_controller.cc @@ -89,6 +89,14 @@ void DataChannelController::OnChannelStateChanged( })); } +size_t DataChannelController::buffered_amount(StreamId sid) const { + RTC_DCHECK_RUN_ON(network_thread()); + if (!data_channel_transport_) { + return 0; + } + return data_channel_transport_->buffered_amount(sid.stream_id_int()); +} + void DataChannelController::OnDataReceived( int channel_id, DataMessageType type, diff --git a/pc/data_channel_controller.h b/pc/data_channel_controller.h index d7ac3706b0..d2a9a1a135 100644 --- a/pc/data_channel_controller.h +++ b/pc/data_channel_controller.h @@ -54,6 +54,7 @@ class DataChannelController : public SctpDataChannelControllerInterface, void RemoveSctpDataStream(StreamId sid) override; void OnChannelStateChanged(SctpDataChannel* channel, DataChannelInterface::DataState state) override; + size_t buffered_amount(StreamId sid) const override; // Implements DataChannelSink. void OnDataReceived(int channel_id, diff --git a/pc/data_channel_controller_unittest.cc b/pc/data_channel_controller_unittest.cc index 7d4e60467e..caf9a76c41 100644 --- a/pc/data_channel_controller_unittest.cc +++ b/pc/data_channel_controller_unittest.cc @@ -41,6 +41,7 @@ class MockDataChannelTransport : public DataChannelTransportInterface { MOCK_METHOD(RTCError, CloseChannel, (int channel_id), (override)); MOCK_METHOD(void, SetDataSink, (DataChannelSink * sink), (override)); MOCK_METHOD(bool, IsReadyToSend, (), (const, override)); + MOCK_METHOD(size_t, buffered_amount, (int channel_id), (const, override)); }; // Convenience class for tests to ensure that shutdown methods for DCC @@ -167,6 +168,20 @@ TEST_F(DataChannelControllerTest, MaxChannels) { } } +TEST_F(DataChannelControllerTest, BufferedAmountIncludesFromTransport) { + NiceMock transport; + EXPECT_CALL(transport, buffered_amount(0)).WillOnce(Return(4711)); + ON_CALL(*pc_, GetSctpSslRole_n).WillByDefault([&]() { + return rtc::SSL_CLIENT; + }); + + DataChannelControllerForTest dcc(pc_.get(), &transport); + auto dc = dcc.InternalCreateDataChannelWithProxy( + "label", InternalDataChannelInit(DataChannelInit())) + .MoveValue(); + EXPECT_EQ(dc->buffered_amount(), 4711u); +} + // Test that while a data channel is in the `kClosing` state, its StreamId does // not get re-used for new channels. Only once the state reaches `kClosed` // should a StreamId be available again for allocation. diff --git a/pc/data_channel_integrationtest.cc b/pc/data_channel_integrationtest.cc index 5a8004c72a..a31481d634 100644 --- a/pc/data_channel_integrationtest.cc +++ b/pc/data_channel_integrationtest.cc @@ -1042,14 +1042,16 @@ TEST_P(DataChannelIntegrationTest, kDefaultTimeout); // Cause a temporary network outage virtual_socket_server()->set_drop_probability(1.0); - // Fill the buffer until queued data starts to build + // Fill the SCTP socket buffer until queued data starts to build. + constexpr size_t kBufferedDataInSctpSocket = 2'000'000; size_t packet_counter = 0; - while (caller()->data_channel()->buffered_amount() < 1 && + while (caller()->data_channel()->buffered_amount() < + kBufferedDataInSctpSocket && packet_counter < 10000) { packet_counter++; caller()->data_channel()->Send(DataBuffer("Sent while blocked")); } - if (caller()->data_channel()->buffered_amount()) { + if (caller()->data_channel()->buffered_amount() > kBufferedDataInSctpSocket) { RTC_LOG(LS_INFO) << "Buffered data after " << packet_counter << " packets"; } else { RTC_LOG(LS_INFO) << "No buffered data after " << packet_counter diff --git a/pc/sctp_data_channel.cc b/pc/sctp_data_channel.cc index 8aa5fbd974..7ec314d2f7 100644 --- a/pc/sctp_data_channel.cc +++ b/pc/sctp_data_channel.cc @@ -485,7 +485,11 @@ Priority SctpDataChannel::priority() const { uint64_t SctpDataChannel::buffered_amount() const { RTC_DCHECK_RUN_ON(network_thread_); - return queued_send_data_.byte_count(); + uint64_t buffered_amount = queued_send_data_.byte_count(); + if (controller_ != nullptr && id_n_.has_value()) { + buffered_amount += controller_->buffered_amount(*id_n_); + } + return buffered_amount; } void SctpDataChannel::Close() { diff --git a/pc/sctp_data_channel.h b/pc/sctp_data_channel.h index fdbf2053e3..0be234bd16 100644 --- a/pc/sctp_data_channel.h +++ b/pc/sctp_data_channel.h @@ -55,6 +55,7 @@ class SctpDataChannelControllerInterface { // Notifies the controller of state changes. virtual void OnChannelStateChanged(SctpDataChannel* data_channel, DataChannelInterface::DataState state) = 0; + virtual size_t buffered_amount(StreamId sid) const = 0; protected: virtual ~SctpDataChannelControllerInterface() {} diff --git a/pc/sctp_transport.cc b/pc/sctp_transport.cc index 7f55e39d9e..5f505e0296 100644 --- a/pc/sctp_transport.cc +++ b/pc/sctp_transport.cc @@ -100,6 +100,12 @@ bool SctpTransport::IsReadyToSend() const { return internal_sctp_transport_->ReadyToSendData(); } +size_t SctpTransport::buffered_amount(int channel_id) const { + RTC_DCHECK_RUN_ON(owner_thread_); + RTC_DCHECK(internal_sctp_transport_); + return internal_sctp_transport_->buffered_amount(channel_id); +} + rtc::scoped_refptr SctpTransport::dtls_transport() const { RTC_DCHECK_RUN_ON(owner_thread_); diff --git a/pc/sctp_transport.h b/pc/sctp_transport.h index 076dee5318..79cb3aed2c 100644 --- a/pc/sctp_transport.h +++ b/pc/sctp_transport.h @@ -52,6 +52,7 @@ class SctpTransport : public SctpTransportInterface, RTCError CloseChannel(int channel_id) override; void SetDataSink(DataChannelSink* sink) override; bool IsReadyToSend() const override; + size_t buffered_amount(int channel_id) const override; // Internal functions void Clear(); diff --git a/pc/sctp_transport_unittest.cc b/pc/sctp_transport_unittest.cc index d18543f20c..f0401c1b10 100644 --- a/pc/sctp_transport_unittest.cc +++ b/pc/sctp_transport_unittest.cc @@ -63,6 +63,7 @@ class FakeCricketSctpTransport : public cricket::SctpTransportInternal { absl::optional max_inbound_streams() const override { return max_inbound_streams_; } + size_t buffered_amount(int sid) const override { return 0; } void SendSignalAssociationChangeCommunicationUp() { ASSERT_TRUE(on_connected_callback_); @@ -212,5 +213,4 @@ TEST_F(SctpTransportTest, CloseWhenTransportCloses) { ASSERT_EQ_WAIT(SctpTransportState::kClosed, observer_.State(), kDefaultTimeout); } - } // namespace webrtc diff --git a/pc/test/fake_data_channel_controller.h b/pc/test/fake_data_channel_controller.h index 89cdce738d..c65449b010 100644 --- a/pc/test/fake_data_channel_controller.h +++ b/pc/test/fake_data_channel_controller.h @@ -128,6 +128,8 @@ class FakeDataChannelController } } + size_t buffered_amount(webrtc::StreamId sid) const override { return 0; } + // Set true to emulate the SCTP stream being blocked by congestion control. void set_send_blocked(bool blocked) { network_thread_->BlockingCall([&]() { diff --git a/test/pc/sctp/fake_sctp_transport.h b/test/pc/sctp/fake_sctp_transport.h index 96c126640c..6aef57a241 100644 --- a/test/pc/sctp/fake_sctp_transport.h +++ b/test/pc/sctp/fake_sctp_transport.h @@ -41,9 +41,14 @@ class FakeSctpTransport : public cricket::SctpTransportInternal { bool ReadyToSendData() override { return true; } void set_debug_name_for_testing(const char* debug_name) override {} - int max_message_size() const { return max_message_size_; } - absl::optional max_outbound_streams() const { return absl::nullopt; } - absl::optional max_inbound_streams() const { return absl::nullopt; } + int max_message_size() const override { return max_message_size_; } + absl::optional max_outbound_streams() const override { + return absl::nullopt; + } + absl::optional max_inbound_streams() const override { + return absl::nullopt; + } + size_t buffered_amount(int sid) const override { return 0; } int local_port() const { RTC_DCHECK(local_port_); return *local_port_;