Skip to content

Commit

Permalink
Updated PeerConnection integration test to fix race condition.
Browse files Browse the repository at this point in the history
The PeerConnection integration test was creating TurnServers on the
stack on the signaling thread. This could cause a race condition problem
when the test was being taken down. Since the turn server was destructed
on the signaling thread, a socket might still try and send to it after
it was destroyed causing a seg fault. This change creates/destroys the
TestTurnServers on the network thread to fix this issue.

Bug: None
Change-Id: I080098502b737f0972ce2fa5357920de057a3312
Reviewed-on: https://webrtc-review.googlesource.com/81301
Reviewed-by: Qingsi Wang <[email protected]>
Reviewed-by: Steve Anton <[email protected]>
Commit-Queue: Seth Hampson <[email protected]>
Cr-Commit-Position: refs/heads/master@{#23590}
  • Loading branch information
Seth Hampson authored and Commit Bot committed Jun 13, 2018
1 parent 2cf61e3 commit aed7164
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 62 deletions.
15 changes: 14 additions & 1 deletion p2p/base/testturnserver.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "rtc_base/ssladapter.h"
#include "rtc_base/sslidentity.h"
#include "rtc_base/thread.h"
#include "rtc_base/thread_checker.h"

namespace cricket {

Expand Down Expand Up @@ -65,24 +66,33 @@ class TestTurnServer : public TurnAuthInterface {
server_.set_auth_hook(this);
}

~TestTurnServer() { RTC_DCHECK(thread_checker_.CalledOnValidThread()); }

void set_enable_otu_nonce(bool enable) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
server_.set_enable_otu_nonce(enable);
}

TurnServer* server() { return &server_; }
TurnServer* server() {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
return &server_;
}

void set_redirect_hook(TurnRedirectInterface* redirect_hook) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
server_.set_redirect_hook(redirect_hook);
}

void set_enable_permission_checks(bool enable) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
server_.set_enable_permission_checks(enable);
}

void AddInternalSocket(const rtc::SocketAddress& int_addr,
ProtocolType proto,
bool ignore_bad_cert = true,
const std::string& common_name = "test turn server") {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
if (proto == cricket::PROTO_UDP) {
server_.AddInternalSocket(
rtc::AsyncUDPSocket::Create(thread_->socketserver(), int_addr),
Expand Down Expand Up @@ -115,6 +125,7 @@ class TestTurnServer : public TurnAuthInterface {
// Finds the first allocation in the server allocation map with a source
// ip and port matching the socket address provided.
TurnServerAllocation* FindAllocation(const rtc::SocketAddress& src) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
const TurnServer::AllocationMap& map = server_.allocations();
for (TurnServer::AllocationMap::const_iterator it = map.begin();
it != map.end(); ++it) {
Expand All @@ -130,11 +141,13 @@ class TestTurnServer : public TurnAuthInterface {
// Obviously, do not use this in a production environment.
virtual bool GetKey(const std::string& username, const std::string& realm,
std::string* key) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
return ComputeStunCredentialHash(username, realm, username, key);
}

TurnServer server_;
rtc::Thread* thread_;
rtc::ThreadChecker thread_checker_;
};

} // namespace cricket
Expand Down
3 changes: 2 additions & 1 deletion p2p/base/turnport.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,8 @@ class TurnPort : public Port {

rtc::AsyncInvoker invoker_;

// Optional TurnCustomizer that can modify outgoing messages.
// Optional TurnCustomizer that can modify outgoing messages. Once set, this
// must outlive the TurnPort's lifetime.
webrtc::TurnCustomizer *turn_customizer_ = nullptr;

friend class TurnEntry;
Expand Down
25 changes: 25 additions & 0 deletions p2p/base/turnserver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ TurnServer::TurnServer(rtc::Thread* thread)
}

TurnServer::~TurnServer() {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
for (InternalSocketMap::iterator it = server_sockets_.begin();
it != server_sockets_.end(); ++it) {
rtc::AsyncPacketSocket* socket = it->first;
Expand All @@ -144,13 +145,15 @@ TurnServer::~TurnServer() {

void TurnServer::AddInternalSocket(rtc::AsyncPacketSocket* socket,
ProtocolType proto) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
RTC_DCHECK(server_sockets_.end() == server_sockets_.find(socket));
server_sockets_[socket] = proto;
socket->SignalReadPacket.connect(this, &TurnServer::OnInternalPacket);
}

void TurnServer::AddInternalServerSocket(rtc::AsyncSocket* socket,
ProtocolType proto) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
RTC_DCHECK(server_listen_sockets_.end() ==
server_listen_sockets_.find(socket));
server_listen_sockets_[socket] = proto;
Expand All @@ -160,17 +163,20 @@ void TurnServer::AddInternalServerSocket(rtc::AsyncSocket* socket,
void TurnServer::SetExternalSocketFactory(
rtc::PacketSocketFactory* factory,
const rtc::SocketAddress& external_addr) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
external_socket_factory_.reset(factory);
external_addr_ = external_addr;
}

void TurnServer::OnNewInternalConnection(rtc::AsyncSocket* socket) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
RTC_DCHECK(server_listen_sockets_.find(socket) !=
server_listen_sockets_.end());
AcceptConnection(socket);
}

void TurnServer::AcceptConnection(rtc::AsyncSocket* server_socket) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
// Check if someone is trying to connect to us.
rtc::SocketAddress accept_addr;
rtc::AsyncSocket* accepted_socket = server_socket->Accept(&accept_addr);
Expand All @@ -187,13 +193,15 @@ void TurnServer::AcceptConnection(rtc::AsyncSocket* server_socket) {

void TurnServer::OnInternalSocketClose(rtc::AsyncPacketSocket* socket,
int err) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
DestroyInternalSocket(socket);
}

void TurnServer::OnInternalPacket(rtc::AsyncPacketSocket* socket,
const char* data, size_t size,
const rtc::SocketAddress& addr,
const rtc::PacketTime& packet_time) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
// Fail if the packet is too small to even contain a channel header.
if (size < TURN_CHANNEL_HEADER_SIZE) {
return;
Expand All @@ -219,6 +227,7 @@ void TurnServer::OnInternalPacket(rtc::AsyncPacketSocket* socket,

void TurnServer::HandleStunMessage(TurnServerConnection* conn, const char* data,
size_t size) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
TurnMessage msg;
rtc::ByteBufferReader buf(data, size);
if (!msg.Read(&buf) || (buf.Length() > 0)) {
Expand Down Expand Up @@ -285,6 +294,7 @@ void TurnServer::HandleStunMessage(TurnServerConnection* conn, const char* data,
}

bool TurnServer::GetKey(const StunMessage* msg, std::string* key) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
const StunByteStringAttribute* username_attr =
msg->GetByteString(STUN_ATTR_USERNAME);
if (!username_attr) {
Expand All @@ -299,6 +309,7 @@ bool TurnServer::CheckAuthorization(TurnServerConnection* conn,
const StunMessage* msg,
const char* data, size_t size,
const std::string& key) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
// RFC 5389, 10.2.2.
RTC_DCHECK(IsStunRequestType(msg->type()));
const StunByteStringAttribute* mi_attr =
Expand Down Expand Up @@ -357,6 +368,7 @@ bool TurnServer::CheckAuthorization(TurnServerConnection* conn,

void TurnServer::HandleBindingRequest(TurnServerConnection* conn,
const StunMessage* req) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
StunMessage response;
InitResponse(req, &response);

Expand All @@ -371,6 +383,7 @@ void TurnServer::HandleBindingRequest(TurnServerConnection* conn,
void TurnServer::HandleAllocateRequest(TurnServerConnection* conn,
const TurnMessage* msg,
const std::string& key) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
// Check the parameters in the request.
const StunUInt32Attribute* transport_attr =
msg->GetUInt32(STUN_ATTR_REQUESTED_TRANSPORT);
Expand Down Expand Up @@ -400,6 +413,7 @@ void TurnServer::HandleAllocateRequest(TurnServerConnection* conn,
}

std::string TurnServer::GenerateNonce(int64_t now) const {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
// Generate a nonce of the form hex(now + HMAC-MD5(nonce_key_, now))
std::string input(reinterpret_cast<const char*>(&now), sizeof(now));
std::string nonce = rtc::hex_encode(input.c_str(), input.size());
Expand All @@ -410,6 +424,7 @@ std::string TurnServer::GenerateNonce(int64_t now) const {
}

bool TurnServer::ValidateNonce(const std::string& nonce) const {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
// Check the size.
if (nonce.size() != kNonceSize) {
return false;
Expand All @@ -435,13 +450,15 @@ bool TurnServer::ValidateNonce(const std::string& nonce) const {
}

TurnServerAllocation* TurnServer::FindAllocation(TurnServerConnection* conn) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
AllocationMap::const_iterator it = allocations_.find(*conn);
return (it != allocations_.end()) ? it->second.get() : nullptr;
}

TurnServerAllocation* TurnServer::CreateAllocation(TurnServerConnection* conn,
int proto,
const std::string& key) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
rtc::AsyncPacketSocket* external_socket = (external_socket_factory_) ?
external_socket_factory_->CreateUdpSocket(external_addr_, 0, 0) : NULL;
if (!external_socket) {
Expand All @@ -459,6 +476,7 @@ TurnServerAllocation* TurnServer::CreateAllocation(TurnServerConnection* conn,
void TurnServer::SendErrorResponse(TurnServerConnection* conn,
const StunMessage* req,
int code, const std::string& reason) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
TurnMessage resp;
InitErrorResponse(req, code, reason, &resp);
RTC_LOG(LS_INFO) << "Sending error response, type=" << resp.type()
Expand All @@ -469,6 +487,7 @@ void TurnServer::SendErrorResponse(TurnServerConnection* conn,
void TurnServer::SendErrorResponseWithRealmAndNonce(
TurnServerConnection* conn, const StunMessage* msg,
int code, const std::string& reason) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
TurnMessage resp;
InitErrorResponse(msg, code, reason, &resp);

Expand All @@ -487,6 +506,7 @@ void TurnServer::SendErrorResponseWithRealmAndNonce(
void TurnServer::SendErrorResponseWithAlternateServer(
TurnServerConnection* conn, const StunMessage* msg,
const rtc::SocketAddress& addr) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
TurnMessage resp;
InitErrorResponse(msg, STUN_ERROR_TRY_ALTERNATE,
STUN_ERROR_REASON_TRY_ALTERNATE_SERVER, &resp);
Expand All @@ -496,6 +516,7 @@ void TurnServer::SendErrorResponseWithAlternateServer(
}

void TurnServer::SendStun(TurnServerConnection* conn, StunMessage* msg) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
rtc::ByteBufferWriter buf;
// Add a SOFTWARE attribute if one is set.
if (!software_.empty()) {
Expand All @@ -508,11 +529,13 @@ void TurnServer::SendStun(TurnServerConnection* conn, StunMessage* msg) {

void TurnServer::Send(TurnServerConnection* conn,
const rtc::ByteBufferWriter& buf) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
rtc::PacketOptions options;
conn->socket()->SendTo(buf.Data(), buf.Length(), conn->src(), options);
}

void TurnServer::OnAllocationDestroyed(TurnServerAllocation* allocation) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
// Removing the internal socket if the connection is not udp.
rtc::AsyncPacketSocket* socket = allocation->conn()->socket();
InternalSocketMap::iterator iter = server_sockets_.find(socket);
Expand All @@ -532,6 +555,7 @@ void TurnServer::OnAllocationDestroyed(TurnServerAllocation* allocation) {
}

void TurnServer::DestroyInternalSocket(rtc::AsyncPacketSocket* socket) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
InternalSocketMap::iterator iter = server_sockets_.find(socket);
if (iter != server_sockets_.end()) {
rtc::AsyncPacketSocket* socket = iter->first;
Expand All @@ -547,6 +571,7 @@ void TurnServer::DestroyInternalSocket(rtc::AsyncPacketSocket* socket) {
}

void TurnServer::FreeSockets() {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
sockets_to_delete_.clear();
}

Expand Down
42 changes: 35 additions & 7 deletions p2p/base/turnserver.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "rtc_base/messagequeue.h"
#include "rtc_base/sigslot.h"
#include "rtc_base/socketaddress.h"
#include "rtc_base/thread_checker.h"

namespace rtc {
class ByteBufferWriter;
Expand Down Expand Up @@ -178,30 +179,54 @@ class TurnServer : public sigslot::has_slots<> {
~TurnServer() override;

// Gets/sets the realm value to use for the server.
const std::string& realm() const { return realm_; }
void set_realm(const std::string& realm) { realm_ = realm; }
const std::string& realm() const {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
return realm_;
}
void set_realm(const std::string& realm) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
realm_ = realm;
}

// Gets/sets the value for the SOFTWARE attribute for TURN messages.
const std::string& software() const { return software_; }
void set_software(const std::string& software) { software_ = software; }
const std::string& software() const {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
return software_;
}
void set_software(const std::string& software) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
software_ = software;
}

const AllocationMap& allocations() const { return allocations_; }
const AllocationMap& allocations() const {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
return allocations_;
}

// Sets the authentication callback; does not take ownership.
void set_auth_hook(TurnAuthInterface* auth_hook) { auth_hook_ = auth_hook; }
void set_auth_hook(TurnAuthInterface* auth_hook) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
auth_hook_ = auth_hook;
}

void set_redirect_hook(TurnRedirectInterface* redirect_hook) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
redirect_hook_ = redirect_hook;
}

void set_enable_otu_nonce(bool enable) { enable_otu_nonce_ = enable; }
void set_enable_otu_nonce(bool enable) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
enable_otu_nonce_ = enable;
}

// If set to true, reject CreatePermission requests to RFC1918 addresses.
void set_reject_private_addresses(bool filter) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
reject_private_addresses_ = filter;
}

void set_enable_permission_checks(bool enable) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
enable_permission_checks_ = enable;
}

Expand All @@ -218,12 +243,14 @@ class TurnServer : public sigslot::has_slots<> {
const rtc::SocketAddress& address);
// For testing only.
std::string SetTimestampForNextNonce(int64_t timestamp) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
ts_for_next_nonce_ = timestamp;
return GenerateNonce(timestamp);
}

void SetStunMessageObserver(
std::unique_ptr<StunMessageObserver> observer) {
RTC_DCHECK(thread_checker_.CalledOnValidThread());
stun_message_observer_ = std::move(observer);
}

Expand Down Expand Up @@ -282,6 +309,7 @@ class TurnServer : public sigslot::has_slots<> {
ProtocolType> ServerSocketMap;

rtc::Thread* thread_;
rtc::ThreadChecker thread_checker_;
std::string nonce_key_;
std::string realm_;
std::string software_;
Expand Down
Loading

0 comments on commit aed7164

Please sign in to comment.