diff --git a/api/peerconnectioninterface.h b/api/peerconnectioninterface.h index fd578539679..be668ad5ef2 100644 --- a/api/peerconnectioninterface.h +++ b/api/peerconnectioninterface.h @@ -109,6 +109,7 @@ #include "rtc_base/rtccertificate.h" #include "rtc_base/rtccertificategenerator.h" #include "rtc_base/socketaddress.h" +#include "rtc_base/sslcertificate.h" #include "rtc_base/sslstreamadapter.h" namespace rtc { @@ -1190,6 +1191,7 @@ struct PeerConnectionDependencies final { // Optional dependencies std::unique_ptr allocator; std::unique_ptr cert_generator; + std::unique_ptr tls_cert_verifier; }; // PeerConnectionFactoryInterface is the factory interface used for creating diff --git a/p2p/base/basicpacketsocketfactory.cc b/p2p/base/basicpacketsocketfactory.cc index 86b9e8541ce..31761c6c83a 100644 --- a/p2p/base/basicpacketsocketfactory.cc +++ b/p2p/base/basicpacketsocketfactory.cc @@ -170,6 +170,7 @@ AsyncPacketSocket* BasicPacketSocketFactory::CreateClientTcpSocket( ssl_adapter->SetAlpnProtocols(tcp_options.tls_alpn_protocols); ssl_adapter->SetEllipticCurves(tcp_options.tls_elliptic_curves); + ssl_adapter->SetCertVerifier(tcp_options.tls_cert_verifier); socket = ssl_adapter; diff --git a/p2p/base/packetsocketfactory.h b/p2p/base/packetsocketfactory.h index c4a60eae765..e5df8fd2d14 100644 --- a/p2p/base/packetsocketfactory.h +++ b/p2p/base/packetsocketfactory.h @@ -16,6 +16,7 @@ #include "rtc_base/constructormagic.h" #include "rtc_base/proxyinfo.h" +#include "rtc_base/sslcertificate.h" namespace rtc { @@ -27,6 +28,9 @@ struct PacketSocketTcpOptions { int opts = 0; std::vector tls_alpn_protocols; std::vector tls_elliptic_curves; + // An optional custom SSL certificate verifier that an API user can provide to + // inject their own certificate verification logic. + SSLCertificateVerifier* tls_cert_verifier = nullptr; }; class AsyncPacketSocket; diff --git a/p2p/base/port_unittest.cc b/p2p/base/port_unittest.cc index c79fac8cfca..5a0dc1a25ed 100644 --- a/p2p/base/port_unittest.cc +++ b/p2p/base/port_unittest.cc @@ -544,7 +544,7 @@ class PortTest : public testing::Test, public sigslot::has_slots<> { &main_, socket_factory, MakeNetwork(addr), 0, 0, username_, password_, ProtocolAddress(server_addr, int_proto), kRelayCredentials, 0, std::string(), std::vector(), std::vector(), - nullptr); + nullptr, nullptr); } RelayPort* CreateGturnPort(const SocketAddress& addr, ProtocolType int_proto, ProtocolType ext_proto) { @@ -1486,7 +1486,7 @@ TEST_F(PortTest, TestTcpNoDelay) { } TEST_F(PortTest, TestDelayedBindingUdp) { - FakeAsyncPacketSocket *socket = new FakeAsyncPacketSocket(); + FakeAsyncPacketSocket* socket = new FakeAsyncPacketSocket(); FakePacketSocketFactory socket_factory; socket_factory.set_next_udp_socket(socket); @@ -1502,7 +1502,7 @@ TEST_F(PortTest, TestDelayedBindingUdp) { } TEST_F(PortTest, TestDelayedBindingTcp) { - FakeAsyncPacketSocket *socket = new FakeAsyncPacketSocket(); + FakeAsyncPacketSocket* socket = new FakeAsyncPacketSocket(); FakePacketSocketFactory socket_factory; socket_factory.set_next_server_tcp_socket(socket); @@ -1525,7 +1525,7 @@ void PortTest::TestCrossFamilyPorts(int type) { SocketAddress("2001:db8::1", 0), SocketAddress("2001:db8::2", 0)}; for (int i = 0; i < 4; i++) { - FakeAsyncPacketSocket *socket = new FakeAsyncPacketSocket(); + FakeAsyncPacketSocket* socket = new FakeAsyncPacketSocket(); if (type == SOCK_DGRAM) { factory.set_next_udp_socket(socket); ports[i].reset(CreateUdpPort(addresses[i], &factory)); @@ -1595,7 +1595,7 @@ TEST_F(PortTest, TestUdpV6CrossTypePorts) { SocketAddress("fe80::2", 0), SocketAddress("::1", 0)}; for (int i = 0; i < 4; i++) { - FakeAsyncPacketSocket *socket = new FakeAsyncPacketSocket(); + FakeAsyncPacketSocket* socket = new FakeAsyncPacketSocket(); factory.set_next_udp_socket(socket); ports[i].reset(CreateUdpPort(addresses[i], &factory)); socket->set_state(AsyncPacketSocket::STATE_BINDING); diff --git a/p2p/base/portallocator.h b/p2p/base/portallocator.h index 9ee248b6656..30ace5eba77 100644 --- a/p2p/base/portallocator.h +++ b/p2p/base/portallocator.h @@ -21,6 +21,7 @@ #include "rtc_base/helpers.h" #include "rtc_base/proxyinfo.h" #include "rtc_base/sigslot.h" +#include "rtc_base/sslcertificate.h" #include "rtc_base/thread.h" #include "rtc_base/thread_checker.h" @@ -180,6 +181,7 @@ struct RelayServerConfig { TlsCertPolicy tls_cert_policy = TlsCertPolicy::TLS_CERT_POLICY_SECURE; std::vector tls_alpn_protocols; std::vector tls_elliptic_curves; + rtc::SSLCertificateVerifier* tls_cert_verifier = nullptr; }; class PortAllocatorSession : public sigslot::has_slots<> { diff --git a/p2p/base/testturnserver.h b/p2p/base/testturnserver.h index 3e670f36d1e..067faddc868 100644 --- a/p2p/base/testturnserver.h +++ b/p2p/base/testturnserver.h @@ -53,9 +53,11 @@ class TestTurnServer : public TurnAuthInterface { TestTurnServer(rtc::Thread* thread, const rtc::SocketAddress& int_addr, const rtc::SocketAddress& udp_ext_addr, - ProtocolType int_protocol = PROTO_UDP) + ProtocolType int_protocol = PROTO_UDP, + bool ignore_bad_cert = true, + const std::string& common_name = "test turn server") : server_(thread), thread_(thread) { - AddInternalSocket(int_addr, int_protocol); + AddInternalSocket(int_addr, int_protocol, ignore_bad_cert, common_name); server_.SetExternalSocketFactory(new rtc::BasicPacketSocketFactory(thread), udp_ext_addr); server_.set_realm(kTestRealm); @@ -78,7 +80,9 @@ class TestTurnServer : public TurnAuthInterface { } void AddInternalSocket(const rtc::SocketAddress& int_addr, - ProtocolType proto) { + ProtocolType proto, + bool ignore_bad_cert = true, + const std::string& common_name = "test turn server") { if (proto == cricket::PROTO_UDP) { server_.AddInternalSocket( rtc::AsyncUDPSocket::Create(thread_->socketserver(), int_addr), @@ -96,8 +100,8 @@ class TestTurnServer : public TurnAuthInterface { rtc::SSLAdapter* adapter = rtc::SSLAdapter::Create(socket); adapter->SetRole(rtc::SSL_SERVER); adapter->SetIdentity( - rtc::SSLIdentity::Generate("test turn server", rtc::KeyParams())); - adapter->SetIgnoreBadCert(true); + rtc::SSLIdentity::Generate(common_name, rtc::KeyParams())); + adapter->SetIgnoreBadCert(ignore_bad_cert); socket = adapter; } socket->Bind(int_addr); diff --git a/p2p/base/turnport.cc b/p2p/base/turnport.cc index 57ce32b82cc..a09c6403c50 100644 --- a/p2p/base/turnport.cc +++ b/p2p/base/turnport.cc @@ -199,13 +199,9 @@ TurnPort::TurnPort(rtc::Thread* thread, int server_priority, const std::string& origin, webrtc::TurnCustomizer* customizer) - : Port(thread, - RELAY_PORT_TYPE, - factory, - network, - username, - password), + : Port(thread, RELAY_PORT_TYPE, factory, network, username, password), server_address_(server_address), + tls_cert_verifier_(nullptr), credentials_(credentials), socket_(socket), resolver_(NULL), @@ -233,7 +229,8 @@ TurnPort::TurnPort(rtc::Thread* thread, const std::string& origin, const std::vector& tls_alpn_protocols, const std::vector& tls_elliptic_curves, - webrtc::TurnCustomizer* customizer) + webrtc::TurnCustomizer* customizer, + rtc::SSLCertificateVerifier* tls_cert_verifier) : Port(thread, RELAY_PORT_TYPE, factory, @@ -245,6 +242,7 @@ TurnPort::TurnPort(rtc::Thread* thread, server_address_(server_address), tls_alpn_protocols_(tls_alpn_protocols), tls_elliptic_curves_(tls_elliptic_curves), + tls_cert_verifier_(tls_cert_verifier), credentials_(credentials), socket_(NULL), resolver_(NULL), @@ -374,6 +372,7 @@ bool TurnPort::CreateTurnClientSocket() { tcp_options.opts = opts; tcp_options.tls_alpn_protocols = tls_alpn_protocols_; tcp_options.tls_elliptic_curves = tls_elliptic_curves_; + tcp_options.tls_cert_verifier = tls_cert_verifier_; socket_ = socket_factory()->CreateClientTcpSocket( rtc::SocketAddress(Network()->GetBestIP(), 0), server_address_.address, proxy(), user_agent(), tcp_options); diff --git a/p2p/base/turnport.h b/p2p/base/turnport.h index 7ebb5f1af65..b0def2a898c 100644 --- a/p2p/base/turnport.h +++ b/p2p/base/turnport.h @@ -22,6 +22,7 @@ #include "p2p/client/basicportallocator.h" #include "rtc_base/asyncinvoker.h" #include "rtc_base/asyncpacketsocket.h" +#include "rtc_base/sslcertificate.h" namespace rtc { class AsyncResolver; @@ -67,24 +68,26 @@ class TurnPort : public Port { // Create a TURN port that will use a new socket, bound to |network| and // using a port in the range between |min_port| and |max_port|. - static TurnPort* Create(rtc::Thread* thread, - rtc::PacketSocketFactory* factory, - rtc::Network* network, - uint16_t min_port, - uint16_t max_port, - const std::string& username, // ice username. - const std::string& password, // ice password. - const ProtocolAddress& server_address, - const RelayCredentials& credentials, - int server_priority, - const std::string& origin, - const std::vector& tls_alpn_protocols, - const std::vector& tls_elliptic_curves, - webrtc::TurnCustomizer* customizer) { + static TurnPort* Create( + rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + uint16_t min_port, + uint16_t max_port, + const std::string& username, // ice username. + const std::string& password, // ice password. + const ProtocolAddress& server_address, + const RelayCredentials& credentials, + int server_priority, + const std::string& origin, + const std::vector& tls_alpn_protocols, + const std::vector& tls_elliptic_curves, + webrtc::TurnCustomizer* customizer, + rtc::SSLCertificateVerifier* tls_cert_verifier = nullptr) { return new TurnPort(thread, factory, network, min_port, max_port, username, password, server_address, credentials, server_priority, origin, tls_alpn_protocols, tls_elliptic_curves, - customizer); + customizer, tls_cert_verifier); } ~TurnPort() override; @@ -214,7 +217,8 @@ class TurnPort : public Port { const std::string& origin, const std::vector& tls_alpn_protocols, const std::vector& tls_elliptic_curves, - webrtc::TurnCustomizer* customizer); + webrtc::TurnCustomizer* customizer, + rtc::SSLCertificateVerifier* tls_cert_verifier = nullptr); // NOTE: This method needs to be accessible for StacPort // return true if entry was created (i.e channel_number consumed). @@ -303,6 +307,7 @@ class TurnPort : public Port { TlsCertPolicy tls_cert_policy_ = TlsCertPolicy::TLS_CERT_POLICY_SECURE; std::vector tls_alpn_protocols_; std::vector tls_elliptic_curves_; + rtc::SSLCertificateVerifier* tls_cert_verifier_; RelayCredentials credentials_; AttemptedServerSet attempted_server_addresses_; diff --git a/p2p/client/turnportfactory.cc b/p2p/client/turnportfactory.cc index 9f24f5f4b60..fc4f9d31e8c 100644 --- a/p2p/client/turnportfactory.cc +++ b/p2p/client/turnportfactory.cc @@ -43,22 +43,12 @@ std::unique_ptr TurnPortFactory::Create( const CreateRelayPortArgs& args, int min_port, int max_port) { - TurnPort* port = TurnPort::Create( - args.network_thread, - args.socket_factory, - args.network, - min_port, - max_port, - args.username, - args.password, - *args.server_address, - args.config->credentials, - args.config->priority, - args.origin, - args.config->tls_alpn_protocols, - args.config->tls_elliptic_curves, - args.turn_customizer); + args.network_thread, args.socket_factory, args.network, min_port, + max_port, args.username, args.password, *args.server_address, + args.config->credentials, args.config->priority, args.origin, + args.config->tls_alpn_protocols, args.config->tls_elliptic_curves, + args.turn_customizer, args.config->tls_cert_verifier); port->SetTlsCertPolicy(args.config->tls_cert_policy); return std::unique_ptr(port); } diff --git a/pc/peerconnection.cc b/pc/peerconnection.cc index 552bd2f5d56..6557fca23ba 100644 --- a/pc/peerconnection.cc +++ b/pc/peerconnection.cc @@ -904,8 +904,10 @@ bool PeerConnection::Initialize( "PeerConnectionObserver"; return false; } + observer_ = dependencies.observer; port_allocator_ = std::move(dependencies.allocator); + tls_cert_verifier_ = std::move(dependencies.tls_cert_verifier); // The port allocator lives on the network thread and should be initialized // there. @@ -4674,6 +4676,11 @@ bool PeerConnection::InitializePortAllocator_n( ConvertIceTransportTypeToCandidateFilter(configuration.type)); port_allocator_->set_max_ipv6_networks(configuration.max_ipv6_networks); + if (tls_cert_verifier_ != nullptr) { + for (auto& turn_server : turn_servers) { + turn_server.tls_cert_verifier = tls_cert_verifier_.get(); + } + } // Call this last since it may create pooled allocator sessions using the // properties set above. port_allocator_->SetConfiguration( diff --git a/pc/peerconnection.h b/pc/peerconnection.h index 1c8c3e01f31..754913152d8 100644 --- a/pc/peerconnection.h +++ b/pc/peerconnection.h @@ -906,6 +906,7 @@ class PeerConnection : public PeerConnectionInternal, PeerConnectionInterface::RTCConfiguration configuration_; std::unique_ptr port_allocator_; + std::unique_ptr tls_cert_verifier_; int port_allocator_flags_ = 0; // One PeerConnection has only one RTCP CNAME. diff --git a/pc/peerconnection_integrationtest.cc b/pc/peerconnection_integrationtest.cc index 6616243ac1f..969fc804e11 100644 --- a/pc/peerconnection_integrationtest.cc +++ b/pc/peerconnection_integrationtest.cc @@ -55,6 +55,7 @@ #include "rtc_base/fakenetwork.h" #include "rtc_base/firewallsocketserver.h" #include "rtc_base/gunit.h" +#include "rtc_base/testcertificateverifier.h" #include "rtc_base/virtualsocketserver.h" #include "test/gmock.h" @@ -227,7 +228,9 @@ class PeerConnectionWrapper : public webrtc::PeerConnectionObserver, rtc::Thread* network_thread, rtc::Thread* worker_thread) { PeerConnectionWrapper* client(new PeerConnectionWrapper(debug_name)); - if (!client->Init(nullptr, nullptr, nullptr, std::move(cert_generator), + webrtc::PeerConnectionDependencies dependencies(nullptr); + dependencies.cert_generator = std::move(cert_generator); + if (!client->Init(nullptr, nullptr, nullptr, std::move(dependencies), network_thread, worker_thread)) { delete client; return nullptr; @@ -579,13 +582,12 @@ class PeerConnectionWrapper : public webrtc::PeerConnectionObserver, explicit PeerConnectionWrapper(const std::string& debug_name) : debug_name_(debug_name) {} - bool Init( - const MediaConstraintsInterface* constraints, - const PeerConnectionFactory::Options* options, - const PeerConnectionInterface::RTCConfiguration* config, - std::unique_ptr cert_generator, - rtc::Thread* network_thread, - rtc::Thread* worker_thread) { + bool Init(const MediaConstraintsInterface* constraints, + const PeerConnectionFactory::Options* options, + const PeerConnectionInterface::RTCConfiguration* config, + webrtc::PeerConnectionDependencies dependencies, + rtc::Thread* network_thread, + rtc::Thread* worker_thread) { // There's an error in this test code if Init ends up being called twice. RTC_DCHECK(!peer_connection_); RTC_DCHECK(!peer_connection_factory_); @@ -625,17 +627,17 @@ class PeerConnectionWrapper : public webrtc::PeerConnectionObserver, if (config) { sdp_semantics_ = config->sdp_semantics; } + + dependencies.allocator = std::move(port_allocator); peer_connection_ = - CreatePeerConnection(std::move(port_allocator), constraints, config, - std::move(cert_generator)); + CreatePeerConnection(constraints, config, std::move(dependencies)); return peer_connection_.get() != nullptr; } rtc::scoped_refptr CreatePeerConnection( - std::unique_ptr port_allocator, const MediaConstraintsInterface* constraints, const PeerConnectionInterface::RTCConfiguration* config, - std::unique_ptr cert_generator) { + webrtc::PeerConnectionDependencies dependencies) { PeerConnectionInterface::RTCConfiguration modified_config; // If |config| is null, this will result in a default configuration being // used. @@ -648,9 +650,15 @@ class PeerConnectionWrapper : public webrtc::PeerConnectionObserver, // ratios and not specific resolutions, is this even necessary? modified_config.set_cpu_adaptation(false); + // Use the legacy interface. + if (constraints != nullptr) { + return peer_connection_factory_->CreatePeerConnection( + modified_config, constraints, std::move(dependencies.allocator), + std::move(dependencies.cert_generator), this); + } + dependencies.observer = this; return peer_connection_factory_->CreatePeerConnection( - modified_config, constraints, std::move(port_allocator), - std::move(cert_generator), this); + modified_config, std::move(dependencies)); } void set_signaling_message_receiver( @@ -1156,19 +1164,21 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { const MediaConstraintsInterface* constraints, const PeerConnectionFactory::Options* options, const RTCConfiguration* config, - std::unique_ptr cert_generator) { + webrtc::PeerConnectionDependencies dependencies) { RTCConfiguration modified_config; if (config) { modified_config = *config; } modified_config.sdp_semantics = sdp_semantics_; - if (!cert_generator) { - cert_generator = rtc::MakeUnique(); + if (!dependencies.cert_generator) { + dependencies.cert_generator = + rtc::MakeUnique(); } std::unique_ptr client( new PeerConnectionWrapper(debug_name)); + if (!client->Init(constraints, options, &modified_config, - std::move(cert_generator), network_thread_.get(), + std::move(dependencies), network_thread_.get(), worker_thread_.get())) { return nullptr; } @@ -1191,11 +1201,13 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { // callee PeerConnections. SdpSemantics original_semantics = sdp_semantics_; sdp_semantics_ = caller_semantics; - caller_ = CreatePeerConnectionWrapper("Caller", nullptr, nullptr, nullptr, - nullptr); + caller_ = CreatePeerConnectionWrapper( + "Caller", nullptr, nullptr, nullptr, + webrtc::PeerConnectionDependencies(nullptr)); sdp_semantics_ = callee_semantics; - callee_ = CreatePeerConnectionWrapper("Callee", nullptr, nullptr, nullptr, - nullptr); + callee_ = CreatePeerConnectionWrapper( + "Callee", nullptr, nullptr, nullptr, + webrtc::PeerConnectionDependencies(nullptr)); sdp_semantics_ = original_semantics; return caller_ && callee_; } @@ -1203,30 +1215,51 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { bool CreatePeerConnectionWrappersWithConstraints( MediaConstraintsInterface* caller_constraints, MediaConstraintsInterface* callee_constraints) { - caller_ = CreatePeerConnectionWrapper("Caller", caller_constraints, nullptr, - nullptr, nullptr); - callee_ = CreatePeerConnectionWrapper("Callee", callee_constraints, nullptr, - nullptr, nullptr); + caller_ = CreatePeerConnectionWrapper( + "Caller", caller_constraints, nullptr, nullptr, + webrtc::PeerConnectionDependencies(nullptr)); + callee_ = CreatePeerConnectionWrapper( + "Callee", callee_constraints, nullptr, nullptr, + webrtc::PeerConnectionDependencies(nullptr)); + return caller_ && callee_; } bool CreatePeerConnectionWrappersWithConfig( const PeerConnectionInterface::RTCConfiguration& caller_config, const PeerConnectionInterface::RTCConfiguration& callee_config) { - caller_ = CreatePeerConnectionWrapper("Caller", nullptr, nullptr, - &caller_config, nullptr); - callee_ = CreatePeerConnectionWrapper("Callee", nullptr, nullptr, - &callee_config, nullptr); + caller_ = CreatePeerConnectionWrapper( + "Caller", nullptr, nullptr, &caller_config, + webrtc::PeerConnectionDependencies(nullptr)); + callee_ = CreatePeerConnectionWrapper( + "Callee", nullptr, nullptr, &callee_config, + webrtc::PeerConnectionDependencies(nullptr)); + return caller_ && callee_; + } + + bool CreatePeerConnectionWrappersWithConfigAndDeps( + const PeerConnectionInterface::RTCConfiguration& caller_config, + webrtc::PeerConnectionDependencies caller_dependencies, + const PeerConnectionInterface::RTCConfiguration& callee_config, + webrtc::PeerConnectionDependencies callee_dependencies) { + caller_ = + CreatePeerConnectionWrapper("Caller", nullptr, nullptr, &caller_config, + std::move(caller_dependencies)); + callee_ = + CreatePeerConnectionWrapper("Callee", nullptr, nullptr, &callee_config, + std::move(callee_dependencies)); return caller_ && callee_; } bool CreatePeerConnectionWrappersWithOptions( const PeerConnectionFactory::Options& caller_options, const PeerConnectionFactory::Options& callee_options) { - caller_ = CreatePeerConnectionWrapper("Caller", nullptr, &caller_options, - nullptr, nullptr); - callee_ = CreatePeerConnectionWrapper("Callee", nullptr, &callee_options, - nullptr, nullptr); + caller_ = CreatePeerConnectionWrapper( + "Caller", nullptr, &caller_options, nullptr, + webrtc::PeerConnectionDependencies(nullptr)); + callee_ = CreatePeerConnectionWrapper( + "Callee", nullptr, &callee_options, nullptr, + webrtc::PeerConnectionDependencies(nullptr)); return caller_ && callee_; } @@ -1236,8 +1269,10 @@ class PeerConnectionIntegrationBaseTest : public testing::Test { new FakeRTCCertificateGenerator()); cert_generator->use_alternate_key(); + webrtc::PeerConnectionDependencies dependencies(nullptr); + dependencies.cert_generator = std::move(cert_generator); return CreatePeerConnectionWrapper("New Peer", nullptr, nullptr, nullptr, - std::move(cert_generator)); + std::move(dependencies)); } // Once called, SDP blobs and ICE candidates will be automatically signaled @@ -3925,6 +3960,150 @@ TEST_P(PeerConnectionIntegrationTest, TurnCustomizerUsedForTurnConnections) { delete SetCalleePcWrapperAndReturnCurrent(nullptr); } +// Verify that a SSLCertificateVerifier passed in through +// PeerConnectionDependencies is actually used by the underlying SSL +// implementation to determine whether a certificate presented by the TURN +// server is accepted by the client. Note that openssladapter_unittest.cc +// contains more detailed, lower-level tests. +TEST_P(PeerConnectionIntegrationTest, + SSLCertificateVerifierUsedForTurnConnections) { + static const rtc::SocketAddress turn_server_internal_address{"88.88.88.0", + 3478}; + static const rtc::SocketAddress turn_server_external_address{"88.88.88.1", 0}; + + // Enable TCP-TLS for the fake turn server. We need to pass in 88.88.88.0 so + // that host name verification passes on the fake certificate. + cricket::TestTurnServer turn_server( + network_thread(), turn_server_internal_address, + turn_server_external_address, cricket::PROTO_TLS, + /*ignore_bad_certs=*/true, "88.88.88.0"); + + webrtc::PeerConnectionInterface::IceServer ice_server; + ice_server.urls.push_back("turns:88.88.88.0:3478?transport=tcp"); + ice_server.username = "test"; + ice_server.password = "test"; + + PeerConnectionInterface::RTCConfiguration client_1_config; + client_1_config.servers.push_back(ice_server); + client_1_config.type = webrtc::PeerConnectionInterface::kRelay; + + PeerConnectionInterface::RTCConfiguration client_2_config; + client_2_config.servers.push_back(ice_server); + // Setting the type to kRelay forces the connection to go through a TURN + // server. + client_2_config.type = webrtc::PeerConnectionInterface::kRelay; + + // Get a copy to the pointer so we can verify calls later. + rtc::TestCertificateVerifier* client_1_cert_verifier = + new rtc::TestCertificateVerifier(); + client_1_cert_verifier->verify_certificate_ = true; + rtc::TestCertificateVerifier* client_2_cert_verifier = + new rtc::TestCertificateVerifier(); + client_2_cert_verifier->verify_certificate_ = true; + + // Create the dependencies with the test certificate verifier. + webrtc::PeerConnectionDependencies client_1_deps(nullptr); + client_1_deps.tls_cert_verifier = + std::unique_ptr(client_1_cert_verifier); + webrtc::PeerConnectionDependencies client_2_deps(nullptr); + client_2_deps.tls_cert_verifier = + std::unique_ptr(client_2_cert_verifier); + + ASSERT_TRUE(CreatePeerConnectionWrappersWithConfigAndDeps( + client_1_config, std::move(client_1_deps), client_2_config, + std::move(client_2_deps))); + ConnectFakeSignaling(); + + // Set "offer to receive audio/video" without adding any tracks, so we just + // set up ICE/DTLS with no media. + PeerConnectionInterface::RTCOfferAnswerOptions options; + options.offer_to_receive_audio = 1; + options.offer_to_receive_video = 1; + caller()->SetOfferAnswerOptions(options); + caller()->CreateAndSetAndSignalOffer(); + ASSERT_TRUE_WAIT(DtlsConnected(), kDefaultTimeout); + + EXPECT_GT(client_1_cert_verifier->call_count_, 0u); + EXPECT_GT(client_2_cert_verifier->call_count_, 0u); + + // Need to free the clients here since they're using things we created on + // the stack. + delete SetCallerPcWrapperAndReturnCurrent(nullptr); + delete SetCalleePcWrapperAndReturnCurrent(nullptr); +} + +TEST_P(PeerConnectionIntegrationTest, + SSLCertificateVerifierFailureUsedForTurnConnectionsFailsConnection) { + static const rtc::SocketAddress turn_server_internal_address{"88.88.88.0", + 3478}; + static const rtc::SocketAddress turn_server_external_address{"88.88.88.1", 0}; + + // Enable TCP-TLS for the fake turn server. We need to pass in 88.88.88.0 so + // that host name verification passes on the fake certificate. + cricket::TestTurnServer turn_server( + network_thread(), turn_server_internal_address, + turn_server_external_address, cricket::PROTO_TLS, + /*ignore_bad_certs=*/true, "88.88.88.0"); + + webrtc::PeerConnectionInterface::IceServer ice_server; + ice_server.urls.push_back("turns:88.88.88.0:3478?transport=tcp"); + ice_server.username = "test"; + ice_server.password = "test"; + + PeerConnectionInterface::RTCConfiguration client_1_config; + client_1_config.servers.push_back(ice_server); + client_1_config.type = webrtc::PeerConnectionInterface::kRelay; + + PeerConnectionInterface::RTCConfiguration client_2_config; + client_2_config.servers.push_back(ice_server); + // Setting the type to kRelay forces the connection to go through a TURN + // server. + client_2_config.type = webrtc::PeerConnectionInterface::kRelay; + + // Get a copy to the pointer so we can verify calls later. + rtc::TestCertificateVerifier* client_1_cert_verifier = + new rtc::TestCertificateVerifier(); + client_1_cert_verifier->verify_certificate_ = false; + rtc::TestCertificateVerifier* client_2_cert_verifier = + new rtc::TestCertificateVerifier(); + client_2_cert_verifier->verify_certificate_ = false; + + // Create the dependencies with the test certificate verifier. + webrtc::PeerConnectionDependencies client_1_deps(nullptr); + client_1_deps.tls_cert_verifier = + std::unique_ptr(client_1_cert_verifier); + webrtc::PeerConnectionDependencies client_2_deps(nullptr); + client_2_deps.tls_cert_verifier = + std::unique_ptr(client_2_cert_verifier); + + ASSERT_TRUE(CreatePeerConnectionWrappersWithConfigAndDeps( + client_1_config, std::move(client_1_deps), client_2_config, + std::move(client_2_deps))); + ConnectFakeSignaling(); + + // Set "offer to receive audio/video" without adding any tracks, so we just + // set up ICE/DTLS with no media. + PeerConnectionInterface::RTCOfferAnswerOptions options; + options.offer_to_receive_audio = 1; + options.offer_to_receive_video = 1; + caller()->SetOfferAnswerOptions(options); + caller()->CreateAndSetAndSignalOffer(); + bool wait_res = true; + // TODO(bugs.webrtc.org/9219): When IceConnectionState is implemented + // properly, should be able to just wait for a state of "failed" instead of + // waiting a fixed 10 seconds. + WAIT_(DtlsConnected(), kDefaultTimeout, wait_res); + ASSERT_FALSE(wait_res); + + EXPECT_GT(client_1_cert_verifier->call_count_, 0u); + EXPECT_GT(client_2_cert_verifier->call_count_, 0u); + + // Need to free the clients here since they're using things we created on + // the stack. + delete SetCallerPcWrapperAndReturnCurrent(nullptr); + delete SetCalleePcWrapperAndReturnCurrent(nullptr); +} + // Test that audio and video flow end-to-end when codec names don't use the // expected casing, given that they're supposed to be case insensitive. To test // this, all but one codec is removed from each media description, and its diff --git a/rtc_base/BUILD.gn b/rtc_base/BUILD.gn index 9ae175a1b33..cd8d52c5366 100644 --- a/rtc_base/BUILD.gn +++ b/rtc_base/BUILD.gn @@ -782,8 +782,8 @@ rtc_static_library("rtc_base_generic") { "openssl.h", "openssladapter.cc", "openssladapter.h", - "opensslcommon.cc", - "opensslcommon.h", + "opensslcertificate.cc", + "opensslcertificate.h", "openssldigest.cc", "openssldigest.h", "opensslidentity.cc", @@ -792,6 +792,8 @@ rtc_static_library("rtc_base_generic") { "opensslsessioncache.h", "opensslstreamadapter.cc", "opensslstreamadapter.h", + "opensslutility.cc", + "opensslutility.h", "physicalsocketserver.cc", "physicalsocketserver.h", "proxyinfo.cc", @@ -819,6 +821,8 @@ rtc_static_library("rtc_base_generic") { "socketstream.h", "ssladapter.cc", "ssladapter.h", + "sslcertificate.cc", + "sslcertificate.h", "sslfingerprint.cc", "sslfingerprint.h", "sslidentity.cc", @@ -881,6 +885,10 @@ rtc_static_library("rtc_base_generic") { configs += [ ":external_ssl_library" ] } + if (rtc_builtin_ssl_root_certificates) { + defines += [ "WEBRTC_ENABLE_BUILT_IN_SSL_ROOT_CERTIFICATES" ] + } + if (is_android) { sources += [ "ifaddrs-android.cc", @@ -1021,6 +1029,7 @@ rtc_source_set("rtc_base_tests_utils") { "sigslottester.h", "sigslottester.h.pump", "testbase64.h", + "testcertificateverifier.h", "testclient.cc", "testclient.h", "testechoserver.cc", @@ -1255,6 +1264,7 @@ if (rtc_include_tests) { } rtc_source_set("rtc_base_unittests") { testonly = true + defines = [] sources = [ "callback_unittest.cc", @@ -1292,8 +1302,8 @@ if (rtc_include_tests) { if (is_posix || is_fuchsia) { sources += [ "openssladapter_unittest.cc", - "opensslcommon_unittest.cc", "opensslsessioncache_unittest.cc", + "opensslutility_unittest.cc", "ssladapter_unittest.cc", "sslidentity_unittest.cc", "sslstreamadapter_unittest.cc", @@ -1321,6 +1331,10 @@ if (rtc_include_tests) { } else { configs += [ ":external_ssl_library" ] } + + if (!rtc_builtin_ssl_root_certificates) { + defines += [ "WEBRTC_DISABLE_BUILT_IN_SSL_ROOT_CERTIFICATES" ] + } } } diff --git a/rtc_base/openssladapter.cc b/rtc_base/openssladapter.cc index 03b3ca8c62a..87ac744c3a5 100644 --- a/rtc_base/openssladapter.cc +++ b/rtc_base/openssladapter.cc @@ -23,20 +23,17 @@ #include #include "rtc_base/openssl.h" -#include "rtc_base/arraysize.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/numerics/safe_conversions.h" -#include "rtc_base/opensslcommon.h" -#include "rtc_base/ptr_util.h" -#include "rtc_base/sslroots.h" +#include "rtc_base/opensslutility.h" #include "rtc_base/stringencode.h" #include "rtc_base/stringutils.h" #include "rtc_base/thread.h" #ifndef OPENSSL_IS_BORINGSSL -// TODO: Use a nicer abstraction for mutex. +// TODO(benwright): Use a nicer abstraction for mutex. #if defined(WEBRTC_WIN) #define MUTEX_TYPE HANDLE @@ -69,7 +66,7 @@ struct CRYPTO_dynlock_value { static int socket_write(BIO* h, const char* buf, int num); static int socket_read(BIO* h, char* buf, int size); static int socket_puts(BIO* h, const char* str); -static long socket_ctrl(BIO* h, int cmd, long arg1, void* arg2); +static long socket_ctrl(BIO* h, int cmd, long arg1, void* arg2); // NOLINT static int socket_new(BIO* h); static int socket_free(BIO* data); @@ -141,7 +138,7 @@ static int socket_puts(BIO* b, const char* str) { return socket_write(b, str, rtc::checked_cast(strlen(str))); } -static long socket_ctrl(BIO* b, int cmd, long num, void* ptr) { +static long socket_ctrl(BIO* b, int cmd, long num, void* ptr) { // NOLINT switch (cmd) { case BIO_CTRL_RESET: return 0; @@ -181,9 +178,7 @@ static void LogSslError() { namespace rtc { -VerificationCallback OpenSSLAdapter::custom_verify_callback_ = nullptr; - -bool OpenSSLAdapter::InitializeSSL(VerificationCallback callback) { +bool OpenSSLAdapter::InitializeSSL() { if (!SSL_library_init()) return false; #if !defined(ADDRESS_SANITIZER) || !defined(WEBRTC_MAC) || defined(WEBRTC_IOS) @@ -193,7 +188,6 @@ bool OpenSSLAdapter::InitializeSSL(VerificationCallback callback) { ERR_load_BIO_strings(); OpenSSL_add_all_algorithms(); RAND_poll(); - custom_verify_callback_ = callback; return true; } @@ -202,9 +196,11 @@ bool OpenSSLAdapter::CleanupSSL() { } OpenSSLAdapter::OpenSSLAdapter(AsyncSocket* socket, - OpenSSLSessionCache* ssl_session_cache) + OpenSSLSessionCache* ssl_session_cache, + SSLCertificateVerifier* ssl_cert_verifier) : SSLAdapter(socket), ssl_session_cache_(ssl_session_cache), + ssl_cert_verifier_(ssl_cert_verifier), state_(SSL_NONE), role_(SSL_CLIENT), ssl_read_needs_write_(false), @@ -214,7 +210,7 @@ OpenSSLAdapter::OpenSSLAdapter(AsyncSocket* socket, ssl_ctx_(nullptr), ssl_mode_(SSL_MODE_TLS), ignore_bad_cert_(false), - custom_verification_succeeded_(false) { + custom_cert_verifier_status_(false) { // If a factory is used, take a reference on the factory's SSL_CTX. // Otherwise, we'll create our own later. // Either way, we'll release our reference via SSL_CTX_free() in Cleanup(). @@ -248,6 +244,12 @@ void OpenSSLAdapter::SetMode(SSLMode mode) { ssl_mode_ = mode; } +void OpenSSLAdapter::SetCertVerifier( + SSLCertificateVerifier* ssl_cert_verifier) { + RTC_DCHECK(!ssl_ctx_); + ssl_cert_verifier_ = ssl_cert_verifier; +} + void OpenSSLAdapter::SetIdentity(SSLIdentity* identity) { RTC_DCHECK(!identity_); identity_.reset(static_cast(identity)); @@ -307,6 +309,7 @@ int OpenSSLAdapter::BeginSSL() { RTC_DCHECK(!ssl_ctx_); ssl_ctx_ = CreateContext(ssl_mode_, false); } + if (!ssl_ctx_) { err = -1; goto ssl_error; @@ -421,7 +424,7 @@ int OpenSSLAdapter::ContinueSSL() { state_ = SSL_CONNECTED; AsyncSocketAdapter::OnConnectEvent(this); -#if 0 // TODO: worry about this +#if 0 // TODO(benwright): worry about this // Don't let ourselves go away during the callbacks PRefPtr lock(this); RTC_LOG(LS_INFO) << " -- onStreamReadable"; @@ -469,7 +472,7 @@ void OpenSSLAdapter::Cleanup() { state_ = SSL_NONE; ssl_read_needs_write_ = false; ssl_write_needs_read_ = false; - custom_verification_succeeded_ = false; + custom_cert_verifier_status_ = false; pending_data_.Clear(); if (ssl_) { @@ -685,7 +688,7 @@ int OpenSSLAdapter::Close() { } Socket::ConnState OpenSSLAdapter::GetState() const { - //if (signal_close_) + // if (signal_close_) // return CS_CONNECTED; ConnState state = socket_->GetState(); if ((state == CS_CONNECTED) @@ -737,7 +740,7 @@ void OpenSSLAdapter::OnReadEvent(AsyncSocket* socket) { return; // Don't let ourselves go away during the callbacks - //PRefPtr lock(this); // TODO: fix this + // PRefPtr lock(this); // TODO(benwright): fix this if (ssl_write_needs_read_) { AsyncSocketAdapter::OnWriteEvent(socket); } @@ -762,7 +765,7 @@ void OpenSSLAdapter::OnWriteEvent(AsyncSocket* socket) { return; // Don't let ourselves go away during the callbacks - //PRefPtr lock(this); // TODO: fix this + // PRefPtr lock(this); // TODO(benwright): fix this if (ssl_read_needs_write_) { AsyncSocketAdapter::OnReadEvent(socket); @@ -787,12 +790,12 @@ void OpenSSLAdapter::OnCloseEvent(AsyncSocket* socket, int err) { } bool OpenSSLAdapter::SSLPostConnectionCheck(SSL* ssl, const std::string& host) { - bool is_valid_cert_name = openssl::VerifyPeerCertMatchesHost(ssl, host) && - (SSL_get_verify_result(ssl) == X509_V_OK || - custom_verification_succeeded_); + bool is_valid_cert_name = + openssl::VerifyPeerCertMatchesHost(ssl, host) && + (SSL_get_verify_result(ssl) == X509_V_OK || custom_cert_verifier_status_); if (!is_valid_cert_name && ignore_bad_cert_) { - RTC_DLOG(LS_WARNING) << "Other TLS post connection checks failed." + RTC_DLOG(LS_WARNING) << "Other TLS post connection checks failed. " "ignore_bad_cert_ set to true. Overriding name " "verification failure!"; is_valid_cert_name = true; @@ -847,7 +850,6 @@ int OpenSSLAdapter::SSLVerifyCallback(int ok, X509_STORE_CTX* store) { << X509_verify_cert_error_string(err); } #endif - // Get our stream pointer from the store SSL* ssl = reinterpret_cast( X509_STORE_CTX_get_ex_data(store, @@ -856,13 +858,15 @@ int OpenSSLAdapter::SSLVerifyCallback(int ok, X509_STORE_CTX* store) { OpenSSLAdapter* stream = reinterpret_cast(SSL_get_app_data(ssl)); - if (!ok && custom_verify_callback_) { - void* cert = - reinterpret_cast(X509_STORE_CTX_get_current_cert(store)); - if (custom_verify_callback_(cert)) { - stream->custom_verification_succeeded_ = true; - RTC_LOG(LS_INFO) << "validated certificate using custom callback"; + if (!ok && stream->ssl_cert_verifier_ != nullptr) { + RTC_LOG(LS_INFO) << "Invoking SSL Verify Callback."; + const OpenSSLCertificate cert(X509_STORE_CTX_get_current_cert(store)); + if (stream->ssl_cert_verifier_->Verify(cert)) { + stream->custom_cert_verifier_status_ = true; + RTC_LOG(LS_INFO) << "Validated certificate using custom callback"; ok = true; + } else { + RTC_LOG(LS_INFO) << "Failed to verify certificate using custom callback"; } } @@ -884,27 +888,6 @@ int OpenSSLAdapter::NewSSLSessionCallback(SSL* ssl, SSL_SESSION* session) { return 1; // We've taken ownership of the session; OpenSSL shouldn't free it. } -bool OpenSSLAdapter::ConfigureTrustedRootCertificates(SSL_CTX* ctx) { - // Add the root cert that we care about to the SSL context - int count_of_added_certs = 0; - for (size_t i = 0; i < arraysize(kSSLCertCertificateList); i++) { - const unsigned char* cert_buffer = kSSLCertCertificateList[i]; - size_t cert_buffer_len = kSSLCertCertificateSizeList[i]; - X509* cert = - d2i_X509(nullptr, &cert_buffer, checked_cast(cert_buffer_len)); - if (cert) { - int return_value = X509_STORE_add_cert(SSL_CTX_get_cert_store(ctx), cert); - if (return_value == 0) { - RTC_LOG(LS_WARNING) << "Unable to add certificate."; - } else { - count_of_added_certs++; - } - X509_free(cert); - } - } - return count_of_added_certs > 0; -} - SSL_CTX* OpenSSLAdapter::CreateContext(SSLMode mode, bool enable_cache) { // Use (D)TLS 1.2. // Note: BoringSSL supports a range of versions by setting max/min version @@ -924,10 +907,15 @@ SSL_CTX* OpenSSLAdapter::CreateContext(SSLMode mode, bool enable_cache) { << "(error=" << error << ')'; return nullptr; } - if (!ConfigureTrustedRootCertificates(ctx)) { + +#ifndef WEBRTC_DISABLE_BUILT_IN_SSL_ROOT_CERTIFICATES + if (!openssl::LoadBuiltinSSLRootCertificates(ctx)) { + RTC_LOG(LS_ERROR) << "SSL_CTX creation failed: Failed to load any trusted " + "ssl root certificates."; SSL_CTX_free(ctx); return nullptr; } +#endif // WEBRTC_DISABLE_BUILT_IN_SSL_ROOT_CERTIFICATES #if !defined(NDEBUG) SSL_CTX_set_info_callback(ctx, SSLInfoCallback); @@ -980,6 +968,7 @@ std::string TransformAlpnProtocols( ////////////////////////////////////////////////////////////////////// OpenSSLAdapterFactory::OpenSSLAdapterFactory() = default; + OpenSSLAdapterFactory::~OpenSSLAdapterFactory() = default; void OpenSSLAdapterFactory::SetMode(SSLMode mode) { @@ -987,10 +976,15 @@ void OpenSSLAdapterFactory::SetMode(SSLMode mode) { ssl_mode_ = mode; } +void OpenSSLAdapterFactory::SetCertVerifier( + SSLCertificateVerifier* ssl_cert_verifier) { + RTC_DCHECK(!ssl_session_cache_); + ssl_cert_verifier_ = ssl_cert_verifier; +} + OpenSSLAdapter* OpenSSLAdapterFactory::CreateAdapter(AsyncSocket* socket) { if (ssl_session_cache_ == nullptr) { - SSL_CTX* ssl_ctx = - OpenSSLAdapter::CreateContext(ssl_mode_, /* enable_cache = */ true); + SSL_CTX* ssl_ctx = OpenSSLAdapter::CreateContext(ssl_mode_, true); if (ssl_ctx == nullptr) { return nullptr; } @@ -998,7 +992,8 @@ OpenSSLAdapter* OpenSSLAdapterFactory::CreateAdapter(AsyncSocket* socket) { ssl_session_cache_ = MakeUnique(ssl_mode_, ssl_ctx); SSL_CTX_free(ssl_ctx); } - return new OpenSSLAdapter(socket, ssl_session_cache_.get()); + return new OpenSSLAdapter(socket, ssl_session_cache_.get(), + ssl_cert_verifier_); } } // namespace rtc diff --git a/rtc_base/openssladapter.h b/rtc_base/openssladapter.h index 5f5eb80c6e7..0de528c0a5c 100644 --- a/rtc_base/openssladapter.h +++ b/rtc_base/openssladapter.h @@ -21,26 +21,35 @@ #include "rtc_base/buffer.h" #include "rtc_base/messagehandler.h" #include "rtc_base/messagequeue.h" +#include "rtc_base/opensslcertificate.h" #include "rtc_base/opensslidentity.h" #include "rtc_base/opensslsessioncache.h" +#include "rtc_base/ptr_util.h" #include "rtc_base/ssladapter.h" namespace rtc { class OpenSSLAdapter : public SSLAdapter, public MessageHandler { public: - static bool InitializeSSL(VerificationCallback callback); + static bool InitializeSSL(); static bool CleanupSSL(); + // Creating an OpenSSLAdapter requires a socket to bind to, an optional + // session cache if you wish to improve performance by caching sessions for + // hostnames you have previously connected to and an optional + // SSLCertificateVerifier which can override any existing trusted roots to + // validate a peer certificate. The cache and verifier are effectively + // immutable after the the SSL connection starts. explicit OpenSSLAdapter(AsyncSocket* socket, - OpenSSLSessionCache* ssl_session_cache = nullptr); + OpenSSLSessionCache* ssl_session_cache = nullptr, + SSLCertificateVerifier* ssl_cert_verifier = nullptr); ~OpenSSLAdapter() override; void SetIgnoreBadCert(bool ignore) override; void SetAlpnProtocols(const std::vector& protos) override; void SetEllipticCurves(const std::vector& curves) override; - void SetMode(SSLMode mode) override; + void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) override; void SetIdentity(SSLIdentity* identity) override; void SetRole(SSLRole role) override; AsyncSocket* Accept(SocketAddress* paddr) override; @@ -53,11 +62,9 @@ class OpenSSLAdapter : public SSLAdapter, public MessageHandler { SocketAddress* paddr, int64_t* timestamp) override; int Close() override; - // Note that the socket returns ST_CONNECTING while SSL is being negotiated. ConnState GetState() const override; bool IsResumedSession() override; - // Creates a new SSL_CTX object, configured for client-to-server usage // with SSLMode |mode|, and if |enable_cache| is true, with support for // storing successful sessions so that they can be later resumed. @@ -87,9 +94,7 @@ class OpenSSLAdapter : public SSLAdapter, public MessageHandler { // Return value and arguments have the same meanings as for Send; |error| is // an output parameter filled with the result of SSL_get_error. int DoSslWrite(const void* pv, size_t cb, int* error); - void OnMessage(Message* msg) override; - bool SSLPostConnectionCheck(SSL* ssl, const std::string& host); #if !defined(NDEBUG) @@ -97,7 +102,6 @@ class OpenSSLAdapter : public SSLAdapter, public MessageHandler { static void SSLInfoCallback(const SSL* ssl, int where, int ret); #endif static int SSLVerifyCallback(int ok, X509_STORE_CTX* store); - static VerificationCallback custom_verify_callback_; friend class OpenSSLStreamAdapter; // for custom_verify_callback_; // If the SSL_CTX was created with |enable_cache| set to true, this callback @@ -105,30 +109,30 @@ class OpenSSLAdapter : public SSLAdapter, public MessageHandler { // to allow its SSL_SESSION* to be cached for later resumption. static int NewSSLSessionCallback(SSL* ssl, SSL_SESSION* session); - static bool ConfigureTrustedRootCertificates(SSL_CTX* ctx); - - // Parent object that maintains shared state. - // Can be null if state sharing is not needed. + // Optional SSL Shared session cache to improve performance. OpenSSLSessionCache* ssl_session_cache_ = nullptr; - + // Optional SSL Certificate verifier which can be set by a third party. + SSLCertificateVerifier* ssl_cert_verifier_ = nullptr; + // The current connection state of the (d)TLS connection. SSLState state_; std::unique_ptr identity_; + // Indicates whethere this is a client or a server. SSLRole role_; bool ssl_read_needs_write_; bool ssl_write_needs_read_; // If true, socket will retain SSL configuration after Close. // TODO(juberti): Remove this unused flag. bool restartable_; - // This buffer is used if SSL_write fails with SSL_ERROR_WANT_WRITE, which // means we need to keep retrying with *the same exact data* until it // succeeds. Afterwards it will be cleared. Buffer pending_data_; - SSL* ssl_; + // Holds the SSL context, which may be shared if an session cache is provided. SSL_CTX* ssl_ctx_; + // Hostname of server that is being connected, used for SNI. std::string ssl_host_name_; - // Do DTLS or not + // Set the adapter to DTLS or TLS mode before creating the context. SSLMode ssl_mode_; // If true, the server certificate need not match the configured hostname. bool ignore_bad_cert_; @@ -136,14 +140,10 @@ class OpenSSLAdapter : public SSLAdapter, public MessageHandler { std::vector alpn_protocols_; // List of elliptic curves to be used in the TLS elliptic curves extension. std::vector elliptic_curves_; - - bool custom_verification_succeeded_; + // Holds the result of the call to run of the ssl_cert_verify_->Verify() + bool custom_cert_verifier_status_; }; -std::string TransformAlpnProtocols(const std::vector& protos); - -///////////////////////////////////////////////////////////////////////////// - // The OpenSSLAdapterFactory is responsbile for creating multiple new // OpenSSLAdapters with a shared SSL_CTX and a shared SSL_SESSION cache. The // SSL_SESSION cache allows existing SSL_SESSIONS to be reused instead of @@ -156,6 +156,10 @@ class OpenSSLAdapterFactory : public SSLAdapterFactory { // the first adapter is created with the factory. If it is called after it // will DCHECK. void SetMode(SSLMode mode) override; + // Set a custom certificate verifier to be passed down to each instance + // created with this factory. This should only ever be set before the first + // call to the factory and cannot be changed after the fact. + void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) override; // Constructs a new socket using the shared OpenSSLSessionCache. This means // existing SSLSessions already in the cache will be reused instead of // re-created for improved performance. @@ -166,11 +170,16 @@ class OpenSSLAdapterFactory : public SSLAdapterFactory { SSLMode ssl_mode_ = SSL_MODE_TLS; // Holds a cache of existing SSL Sessions. std::unique_ptr ssl_session_cache_; + // Provides an optional custom callback for verifying SSL certificates, this + // in currently only used for TLS-TURN connections. + SSLCertificateVerifier* ssl_cert_verifier_ = nullptr; // TODO(benwright): Remove this when context is moved to OpenSSLCommon. // Hold a friend class to the OpenSSLAdapter to retrieve the context. friend class OpenSSLAdapter; }; +std::string TransformAlpnProtocols(const std::vector& protos); + } // namespace rtc #endif // RTC_BASE_OPENSSLADAPTER_H_ diff --git a/rtc_base/openssladapter_unittest.cc b/rtc_base/openssladapter_unittest.cc index d043353ac1f..d6f34b0873c 100644 --- a/rtc_base/openssladapter_unittest.cc +++ b/rtc_base/openssladapter_unittest.cc @@ -12,10 +12,46 @@ #include #include +#include "rtc_base/asyncsocket.h" #include "rtc_base/gunit.h" #include "rtc_base/openssladapter.h" +#include "rtc_base/ptr_util.h" +#include "test/gmock.h" namespace rtc { +namespace { + +class MockAsyncSocket : public AsyncSocket { + public: + virtual ~MockAsyncSocket() = default; + MOCK_METHOD1(Accept, AsyncSocket*(SocketAddress*)); + MOCK_CONST_METHOD0(GetLocalAddress, SocketAddress()); + MOCK_CONST_METHOD0(GetRemoteAddress, SocketAddress()); + MOCK_METHOD1(Bind, int(const SocketAddress&)); + MOCK_METHOD1(Connect, int(const SocketAddress&)); + MOCK_METHOD2(Send, int(const void*, size_t)); + MOCK_METHOD3(SendTo, int(const void*, size_t, const SocketAddress&)); + MOCK_METHOD3(Recv, int(void*, size_t, int64_t*)); + MOCK_METHOD4(RecvFrom, int(void*, size_t, SocketAddress*, int64_t*)); + MOCK_METHOD1(Listen, int(int)); + MOCK_METHOD0(Close, int()); + MOCK_CONST_METHOD0(GetError, int()); + MOCK_METHOD1(SetError, void(int)); + MOCK_CONST_METHOD0(GetState, ConnState()); + MOCK_METHOD2(GetOption, int(Option, int*)); + MOCK_METHOD2(SetOption, int(Option, int)); +}; + +class MockCertVerifier : public SSLCertificateVerifier { + public: + virtual ~MockCertVerifier() = default; + MOCK_METHOD1(Verify, bool(const SSLCertificate&)); +}; + +} // namespace + +using ::testing::_; +using ::testing::Return; TEST(OpenSSLAdapterTest, TestTransformAlpnProtocols) { EXPECT_EQ("", TransformAlpnProtocols(std::vector())); @@ -38,4 +74,36 @@ TEST(OpenSSLAdapterTest, TestTransformAlpnProtocols) { EXPECT_EQ(expected_response.str(), TransformAlpnProtocols(alpn_protos)); } +// Verifies that SSLStart works when OpenSSLAdapter is started in standalone +// mode. +TEST(OpenSSLAdapterTest, TestBeginSSLBeforeConnection) { + AsyncSocket* async_socket = new MockAsyncSocket(); + OpenSSLAdapter adapter(async_socket); + EXPECT_EQ(adapter.StartSSL("webrtc.org", false), 0); +} + +// Verifies that the adapter factory can create new adapters. +TEST(OpenSSLAdapterFactoryTest, CreateSingleOpenSSLAdapter) { + OpenSSLAdapterFactory adapter_factory; + AsyncSocket* async_socket = new MockAsyncSocket(); + auto simple_adapter = std::unique_ptr( + adapter_factory.CreateAdapter(async_socket)); + EXPECT_NE(simple_adapter, nullptr); +} + +// Verifies that setting a custom verifier still allows for adapters to be +// created. +TEST(OpenSSLAdapterFactoryTest, CreateWorksWithCustomVerifier) { + MockCertVerifier* mock_verifier = new MockCertVerifier(); + EXPECT_CALL(*mock_verifier, Verify(_)).WillRepeatedly(Return(true)); + auto cert_verifier = std::unique_ptr(mock_verifier); + + OpenSSLAdapterFactory adapter_factory; + adapter_factory.SetCertVerifier(cert_verifier.get()); + AsyncSocket* async_socket = new MockAsyncSocket(); + auto simple_adapter = std::unique_ptr( + adapter_factory.CreateAdapter(async_socket)); + EXPECT_NE(simple_adapter, nullptr); +} + } // namespace rtc diff --git a/rtc_base/opensslcertificate.cc b/rtc_base/opensslcertificate.cc new file mode 100644 index 00000000000..005e96fbac8 --- /dev/null +++ b/rtc_base/opensslcertificate.cc @@ -0,0 +1,321 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "rtc_base/opensslcertificate.h" + +#include +#include +#include + +#if defined(WEBRTC_WIN) +// Must be included first before openssl headers. +#include "rtc_base/win32.h" // NOLINT +#endif // WEBRTC_WIN + +#include +#include +#include +#include +#include +#include + +#include "rtc_base/arraysize.h" +#include "rtc_base/checks.h" +#include "rtc_base/helpers.h" +#include "rtc_base/logging.h" +#include "rtc_base/numerics/safe_conversions.h" +#include "rtc_base/openssl.h" +#include "rtc_base/openssldigest.h" +#include "rtc_base/opensslidentity.h" +#include "rtc_base/opensslutility.h" +#include "rtc_base/ptr_util.h" +#ifndef WEBRTC_DISABLE_BUILT_IN_SSL_ROOT_CERTIFICATES +#include "rtc_base/sslroots.h" +#endif + +namespace rtc { + +////////////////////////////////////////////////////////////////////// +// OpenSSLCertificate +////////////////////////////////////////////////////////////////////// + +// We could have exposed a myriad of parameters for the crypto stuff, +// but keeping it simple seems best. + +// Random bits for certificate serial number +static const int SERIAL_RAND_BITS = 64; + +// Generate a self-signed certificate, with the public key from the +// given key pair. Caller is responsible for freeing the returned object. +static X509* MakeCertificate(EVP_PKEY* pkey, const SSLIdentityParams& params) { + RTC_LOG(LS_INFO) << "Making certificate for " << params.common_name; + X509* x509 = nullptr; + BIGNUM* serial_number = nullptr; + X509_NAME* name = nullptr; + time_t epoch_off = 0; // Time offset since epoch. + + if ((x509 = X509_new()) == nullptr) + goto error; + + if (!X509_set_pubkey(x509, pkey)) + goto error; + + // serial number + // temporary reference to serial number inside x509 struct + ASN1_INTEGER* asn1_serial_number; + if ((serial_number = BN_new()) == nullptr || + !BN_pseudo_rand(serial_number, SERIAL_RAND_BITS, 0, 0) || + (asn1_serial_number = X509_get_serialNumber(x509)) == nullptr || + !BN_to_ASN1_INTEGER(serial_number, asn1_serial_number)) + goto error; + + if (!X509_set_version(x509, 2L)) // version 3 + goto error; + + // There are a lot of possible components for the name entries. In + // our P2P SSL mode however, the certificates are pre-exchanged + // (through the secure XMPP channel), and so the certificate + // identification is arbitrary. It can't be empty, so we set some + // arbitrary common_name. Note that this certificate goes out in + // clear during SSL negotiation, so there may be a privacy issue in + // putting anything recognizable here. + if ((name = X509_NAME_new()) == nullptr || + !X509_NAME_add_entry_by_NID(name, NID_commonName, MBSTRING_UTF8, + (unsigned char*)params.common_name.c_str(), + -1, -1, 0) || + !X509_set_subject_name(x509, name) || !X509_set_issuer_name(x509, name)) + goto error; + + if (!X509_time_adj(X509_get_notBefore(x509), params.not_before, &epoch_off) || + !X509_time_adj(X509_get_notAfter(x509), params.not_after, &epoch_off)) + goto error; + + if (!X509_sign(x509, pkey, EVP_sha256())) + goto error; + + BN_free(serial_number); + X509_NAME_free(name); + RTC_LOG(LS_INFO) << "Returning certificate"; + return x509; + +error: + BN_free(serial_number); + X509_NAME_free(name); + X509_free(x509); + return nullptr; +} + +#if !defined(NDEBUG) +// Print a certificate to the log, for debugging. +static void PrintCert(X509* x509) { + BIO* temp_memory_bio = BIO_new(BIO_s_mem()); + if (!temp_memory_bio) { + RTC_DLOG_F(LS_ERROR) << "Failed to allocate temporary memory bio"; + return; + } + X509_print_ex(temp_memory_bio, x509, XN_FLAG_SEP_CPLUS_SPC, 0); + BIO_write(temp_memory_bio, "\0", 1); + char* buffer; + BIO_get_mem_data(temp_memory_bio, &buffer); + RTC_DLOG(LS_VERBOSE) << buffer; + BIO_free(temp_memory_bio); +} +#endif + +OpenSSLCertificate::OpenSSLCertificate(X509* x509) : x509_(x509) { + AddReference(); +} + +OpenSSLCertificate* OpenSSLCertificate::Generate( + OpenSSLKeyPair* key_pair, + const SSLIdentityParams& params) { + SSLIdentityParams actual_params(params); + if (actual_params.common_name.empty()) { + // Use a random string, arbitrarily 8chars long. + actual_params.common_name = CreateRandomString(8); + } + X509* x509 = MakeCertificate(key_pair->pkey(), actual_params); + if (!x509) { + openssl::LogSSLErrors("Generating certificate"); + return nullptr; + } +#if !defined(NDEBUG) + PrintCert(x509); +#endif + OpenSSLCertificate* ret = new OpenSSLCertificate(x509); + X509_free(x509); + return ret; +} + +OpenSSLCertificate* OpenSSLCertificate::FromPEMString( + const std::string& pem_string) { + BIO* bio = BIO_new_mem_buf(const_cast(pem_string.c_str()), -1); + if (!bio) + return nullptr; + BIO_set_mem_eof_return(bio, 0); + X509* x509 = + PEM_read_bio_X509(bio, nullptr, nullptr, const_cast("\0")); + BIO_free(bio); // Frees the BIO, but not the pointed-to string. + + if (!x509) + return nullptr; + + OpenSSLCertificate* ret = new OpenSSLCertificate(x509); + X509_free(x509); + return ret; +} + +// NOTE: This implementation only functions correctly after InitializeSSL +// and before CleanupSSL. +bool OpenSSLCertificate::GetSignatureDigestAlgorithm( + std::string* algorithm) const { + int nid = X509_get_signature_nid(x509_); + switch (nid) { + case NID_md5WithRSA: + case NID_md5WithRSAEncryption: + *algorithm = DIGEST_MD5; + break; + case NID_ecdsa_with_SHA1: + case NID_dsaWithSHA1: + case NID_dsaWithSHA1_2: + case NID_sha1WithRSA: + case NID_sha1WithRSAEncryption: + *algorithm = DIGEST_SHA_1; + break; + case NID_ecdsa_with_SHA224: + case NID_sha224WithRSAEncryption: + case NID_dsa_with_SHA224: + *algorithm = DIGEST_SHA_224; + break; + case NID_ecdsa_with_SHA256: + case NID_sha256WithRSAEncryption: + case NID_dsa_with_SHA256: + *algorithm = DIGEST_SHA_256; + break; + case NID_ecdsa_with_SHA384: + case NID_sha384WithRSAEncryption: + *algorithm = DIGEST_SHA_384; + break; + case NID_ecdsa_with_SHA512: + case NID_sha512WithRSAEncryption: + *algorithm = DIGEST_SHA_512; + break; + default: + // Unknown algorithm. There are several unhandled options that are less + // common and more complex. + RTC_LOG(LS_ERROR) << "Unknown signature algorithm NID: " << nid; + algorithm->clear(); + return false; + } + return true; +} + +bool OpenSSLCertificate::ComputeDigest(const std::string& algorithm, + unsigned char* digest, + size_t size, + size_t* length) const { + return ComputeDigest(x509_, algorithm, digest, size, length); +} + +bool OpenSSLCertificate::ComputeDigest(const X509* x509, + const std::string& algorithm, + unsigned char* digest, + size_t size, + size_t* length) { + const EVP_MD* md; + unsigned int n; + + if (!OpenSSLDigest::GetDigestEVP(algorithm, &md)) + return false; + + if (size < static_cast(EVP_MD_size(md))) + return false; + + X509_digest(x509, md, digest, &n); + + *length = n; + + return true; +} + +OpenSSLCertificate::~OpenSSLCertificate() { + X509_free(x509_); +} + +OpenSSLCertificate* OpenSSLCertificate::GetReference() const { + return new OpenSSLCertificate(x509_); +} + +std::string OpenSSLCertificate::ToPEMString() const { + BIO* bio = BIO_new(BIO_s_mem()); + if (!bio) { + FATAL() << "unreachable code"; + } + if (!PEM_write_bio_X509(bio, x509_)) { + BIO_free(bio); + FATAL() << "unreachable code"; + } + BIO_write(bio, "\0", 1); + char* buffer; + BIO_get_mem_data(bio, &buffer); + std::string ret(buffer); + BIO_free(bio); + return ret; +} + +void OpenSSLCertificate::ToDER(Buffer* der_buffer) const { + // In case of failure, make sure to leave the buffer empty. + der_buffer->SetSize(0); + + // Calculates the DER representation of the certificate, from scratch. + BIO* bio = BIO_new(BIO_s_mem()); + if (!bio) { + FATAL() << "unreachable code"; + } + if (!i2d_X509_bio(bio, x509_)) { + BIO_free(bio); + FATAL() << "unreachable code"; + } + char* data; + size_t length = BIO_get_mem_data(bio, &data); + der_buffer->SetData(data, length); + BIO_free(bio); +} + +void OpenSSLCertificate::AddReference() const { + RTC_DCHECK(x509_ != nullptr); + X509_up_ref(x509_); +} + +bool OpenSSLCertificate::operator==(const OpenSSLCertificate& other) const { + return X509_cmp(x509_, other.x509_) == 0; +} + +bool OpenSSLCertificate::operator!=(const OpenSSLCertificate& other) const { + return !(*this == other); +} + +// Documented in sslidentity.h. +int64_t OpenSSLCertificate::CertificateExpirationTime() const { + ASN1_TIME* expire_time = X509_get_notAfter(x509_); + bool long_format; + + if (expire_time->type == V_ASN1_UTCTIME) { + long_format = false; + } else if (expire_time->type == V_ASN1_GENERALIZEDTIME) { + long_format = true; + } else { + return -1; + } + + return ASN1TimeToSec(expire_time->data, expire_time->length, long_format); +} + +} // namespace rtc diff --git a/rtc_base/opensslcertificate.h b/rtc_base/opensslcertificate.h new file mode 100644 index 00000000000..c730ffd0dc9 --- /dev/null +++ b/rtc_base/opensslcertificate.h @@ -0,0 +1,80 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef RTC_BASE_OPENSSLCERTIFICATE_H_ +#define RTC_BASE_OPENSSLCERTIFICATE_H_ + +#include +#include + +#include +#include + +#include "rtc_base/checks.h" +#include "rtc_base/constructormagic.h" +#include "rtc_base/sslcertificate.h" +#include "rtc_base/sslidentity.h" + +typedef struct ssl_ctx_st SSL_CTX; + +namespace rtc { + +class OpenSSLKeyPair; + +// OpenSSLCertificate encapsulates an OpenSSL X509* certificate object, +// which is also reference counted inside the OpenSSL library. +class OpenSSLCertificate : public SSLCertificate { + public: + // X509 object has its reference count incremented. So the caller and + // OpenSSLCertificate share ownership. + explicit OpenSSLCertificate(X509* x509); + + static OpenSSLCertificate* Generate(OpenSSLKeyPair* key_pair, + const SSLIdentityParams& params); + static OpenSSLCertificate* FromPEMString(const std::string& pem_string); + + ~OpenSSLCertificate() override; + + OpenSSLCertificate* GetReference() const override; + + X509* x509() const { return x509_; } + + std::string ToPEMString() const override; + void ToDER(Buffer* der_buffer) const override; + bool operator==(const OpenSSLCertificate& other) const; + bool operator!=(const OpenSSLCertificate& other) const; + + // Compute the digest of the certificate given algorithm + bool ComputeDigest(const std::string& algorithm, + unsigned char* digest, + size_t size, + size_t* length) const override; + + // Compute the digest of a certificate as an X509 * + static bool ComputeDigest(const X509* x509, + const std::string& algorithm, + unsigned char* digest, + size_t size, + size_t* length); + + bool GetSignatureDigestAlgorithm(std::string* algorithm) const override; + + int64_t CertificateExpirationTime() const override; + + private: + void AddReference() const; + + X509* x509_; // NOT OWNED + RTC_DISALLOW_COPY_AND_ASSIGN(OpenSSLCertificate); +}; + +} // namespace rtc + +#endif // RTC_BASE_OPENSSLCERTIFICATE_H_ diff --git a/rtc_base/opensslidentity.cc b/rtc_base/opensslidentity.cc index 9f7c63b06c4..a86a757d01a 100644 --- a/rtc_base/opensslidentity.cc +++ b/rtc_base/opensslidentity.cc @@ -11,6 +11,8 @@ #include "rtc_base/opensslidentity.h" #include +#include +#include #if defined(WEBRTC_WIN) // Must be included first before openssl headers. @@ -29,6 +31,7 @@ #include "rtc_base/logging.h" #include "rtc_base/openssl.h" #include "rtc_base/openssldigest.h" +#include "rtc_base/opensslutility.h" #include "rtc_base/ptr_util.h" namespace rtc { @@ -36,9 +39,6 @@ namespace rtc { // We could have exposed a myriad of parameters for the crypto stuff, // but keeping it simple seems best. -// Random bits for certificate serial number -static const int SERIAL_RAND_BITS = 64; - // Generate a key pair. Caller is responsible for freeing the returned object. static EVP_PKEY* MakeKey(const KeyParams& key_params) { RTC_LOG(LS_INFO) << "Making key pair"; @@ -93,81 +93,10 @@ static EVP_PKEY* MakeKey(const KeyParams& key_params) { return pkey; } -// Generate a self-signed certificate, with the public key from the -// given key pair. Caller is responsible for freeing the returned object. -static X509* MakeCertificate(EVP_PKEY* pkey, const SSLIdentityParams& params) { - RTC_LOG(LS_INFO) << "Making certificate for " << params.common_name; - X509* x509 = nullptr; - BIGNUM* serial_number = nullptr; - X509_NAME* name = nullptr; - time_t epoch_off = 0; // Time offset since epoch. - - if ((x509 = X509_new()) == nullptr) - goto error; - - if (!X509_set_pubkey(x509, pkey)) - goto error; - - // serial number - // temporary reference to serial number inside x509 struct - ASN1_INTEGER* asn1_serial_number; - if ((serial_number = BN_new()) == nullptr || - !BN_pseudo_rand(serial_number, SERIAL_RAND_BITS, 0, 0) || - (asn1_serial_number = X509_get_serialNumber(x509)) == nullptr || - !BN_to_ASN1_INTEGER(serial_number, asn1_serial_number)) - goto error; - - if (!X509_set_version(x509, 2L)) // version 3 - goto error; - - // There are a lot of possible components for the name entries. In - // our P2P SSL mode however, the certificates are pre-exchanged - // (through the secure XMPP channel), and so the certificate - // identification is arbitrary. It can't be empty, so we set some - // arbitrary common_name. Note that this certificate goes out in - // clear during SSL negotiation, so there may be a privacy issue in - // putting anything recognizable here. - if ((name = X509_NAME_new()) == nullptr || - !X509_NAME_add_entry_by_NID(name, NID_commonName, MBSTRING_UTF8, - (unsigned char*)params.common_name.c_str(), - -1, -1, 0) || - !X509_set_subject_name(x509, name) || !X509_set_issuer_name(x509, name)) - goto error; - - if (!X509_time_adj(X509_get_notBefore(x509), params.not_before, &epoch_off) || - !X509_time_adj(X509_get_notAfter(x509), params.not_after, &epoch_off)) - goto error; - - if (!X509_sign(x509, pkey, EVP_sha256())) - goto error; - - BN_free(serial_number); - X509_NAME_free(name); - RTC_LOG(LS_INFO) << "Returning certificate"; - return x509; - -error: - BN_free(serial_number); - X509_NAME_free(name); - X509_free(x509); - return nullptr; -} - -// This dumps the SSL error stack to the log. -static void LogSSLErrors(const std::string& prefix) { - char error_buf[200]; - unsigned long err; - - while ((err = ERR_get_error()) != 0) { - ERR_error_string_n(err, error_buf, sizeof(error_buf)); - RTC_LOG(LS_ERROR) << prefix << ": " << error_buf << "\n"; - } -} - OpenSSLKeyPair* OpenSSLKeyPair::Generate(const KeyParams& key_params) { EVP_PKEY* pkey = MakeKey(key_params); if (!pkey) { - LogSSLErrors("Generating key pair"); + openssl::LogSSLErrors("Generating key pair"); return nullptr; } return new OpenSSLKeyPair(pkey); @@ -261,212 +190,6 @@ bool OpenSSLKeyPair::operator!=(const OpenSSLKeyPair& other) const { return !(*this == other); } -#if !defined(NDEBUG) -// Print a certificate to the log, for debugging. -static void PrintCert(X509* x509) { - BIO* temp_memory_bio = BIO_new(BIO_s_mem()); - if (!temp_memory_bio) { - RTC_DLOG_F(LS_ERROR) << "Failed to allocate temporary memory bio"; - return; - } - X509_print_ex(temp_memory_bio, x509, XN_FLAG_SEP_CPLUS_SPC, 0); - BIO_write(temp_memory_bio, "\0", 1); - char* buffer; - BIO_get_mem_data(temp_memory_bio, &buffer); - RTC_DLOG(LS_VERBOSE) << buffer; - BIO_free(temp_memory_bio); -} -#endif - -OpenSSLCertificate::OpenSSLCertificate(X509* x509) : x509_(x509) { - AddReference(); -} - -OpenSSLCertificate* OpenSSLCertificate::Generate( - OpenSSLKeyPair* key_pair, - const SSLIdentityParams& params) { - SSLIdentityParams actual_params(params); - if (actual_params.common_name.empty()) { - // Use a random string, arbitrarily 8chars long. - actual_params.common_name = CreateRandomString(8); - } - X509* x509 = MakeCertificate(key_pair->pkey(), actual_params); - if (!x509) { - LogSSLErrors("Generating certificate"); - return nullptr; - } -#if !defined(NDEBUG) - PrintCert(x509); -#endif - OpenSSLCertificate* ret = new OpenSSLCertificate(x509); - X509_free(x509); - return ret; -} - -OpenSSLCertificate* OpenSSLCertificate::FromPEMString( - const std::string& pem_string) { - BIO* bio = BIO_new_mem_buf(const_cast(pem_string.c_str()), -1); - if (!bio) - return nullptr; - BIO_set_mem_eof_return(bio, 0); - X509* x509 = - PEM_read_bio_X509(bio, nullptr, nullptr, const_cast("\0")); - BIO_free(bio); // Frees the BIO, but not the pointed-to string. - - if (!x509) - return nullptr; - - OpenSSLCertificate* ret = new OpenSSLCertificate(x509); - X509_free(x509); - return ret; -} - -// NOTE: This implementation only functions correctly after InitializeSSL -// and before CleanupSSL. -bool OpenSSLCertificate::GetSignatureDigestAlgorithm( - std::string* algorithm) const { - int nid = X509_get_signature_nid(x509_); - switch (nid) { - case NID_md5WithRSA: - case NID_md5WithRSAEncryption: - *algorithm = DIGEST_MD5; - break; - case NID_ecdsa_with_SHA1: - case NID_dsaWithSHA1: - case NID_dsaWithSHA1_2: - case NID_sha1WithRSA: - case NID_sha1WithRSAEncryption: - *algorithm = DIGEST_SHA_1; - break; - case NID_ecdsa_with_SHA224: - case NID_sha224WithRSAEncryption: - case NID_dsa_with_SHA224: - *algorithm = DIGEST_SHA_224; - break; - case NID_ecdsa_with_SHA256: - case NID_sha256WithRSAEncryption: - case NID_dsa_with_SHA256: - *algorithm = DIGEST_SHA_256; - break; - case NID_ecdsa_with_SHA384: - case NID_sha384WithRSAEncryption: - *algorithm = DIGEST_SHA_384; - break; - case NID_ecdsa_with_SHA512: - case NID_sha512WithRSAEncryption: - *algorithm = DIGEST_SHA_512; - break; - default: - // Unknown algorithm. There are several unhandled options that are less - // common and more complex. - RTC_LOG(LS_ERROR) << "Unknown signature algorithm NID: " << nid; - algorithm->clear(); - return false; - } - return true; -} - -bool OpenSSLCertificate::ComputeDigest(const std::string& algorithm, - unsigned char* digest, - size_t size, - size_t* length) const { - return ComputeDigest(x509_, algorithm, digest, size, length); -} - -bool OpenSSLCertificate::ComputeDigest(const X509* x509, - const std::string& algorithm, - unsigned char* digest, - size_t size, - size_t* length) { - const EVP_MD* md; - unsigned int n; - - if (!OpenSSLDigest::GetDigestEVP(algorithm, &md)) - return false; - - if (size < static_cast(EVP_MD_size(md))) - return false; - - X509_digest(x509, md, digest, &n); - - *length = n; - - return true; -} - -OpenSSLCertificate::~OpenSSLCertificate() { - X509_free(x509_); -} - -OpenSSLCertificate* OpenSSLCertificate::GetReference() const { - return new OpenSSLCertificate(x509_); -} - -std::string OpenSSLCertificate::ToPEMString() const { - BIO* bio = BIO_new(BIO_s_mem()); - if (!bio) { - FATAL() << "unreachable code"; - } - if (!PEM_write_bio_X509(bio, x509_)) { - BIO_free(bio); - FATAL() << "unreachable code"; - } - BIO_write(bio, "\0", 1); - char* buffer; - BIO_get_mem_data(bio, &buffer); - std::string ret(buffer); - BIO_free(bio); - return ret; -} - -void OpenSSLCertificate::ToDER(Buffer* der_buffer) const { - // In case of failure, make sure to leave the buffer empty. - der_buffer->SetSize(0); - - // Calculates the DER representation of the certificate, from scratch. - BIO* bio = BIO_new(BIO_s_mem()); - if (!bio) { - FATAL() << "unreachable code"; - } - if (!i2d_X509_bio(bio, x509_)) { - BIO_free(bio); - FATAL() << "unreachable code"; - } - char* data; - size_t length = BIO_get_mem_data(bio, &data); - der_buffer->SetData(data, length); - BIO_free(bio); -} - -void OpenSSLCertificate::AddReference() const { - RTC_DCHECK(x509_ != nullptr); - X509_up_ref(x509_); -} - -bool OpenSSLCertificate::operator==(const OpenSSLCertificate& other) const { - return X509_cmp(x509_, other.x509_) == 0; -} - -bool OpenSSLCertificate::operator!=(const OpenSSLCertificate& other) const { - return !(*this == other); -} - -// Documented in sslidentity.h. -int64_t OpenSSLCertificate::CertificateExpirationTime() const { - ASN1_TIME* expire_time = X509_get_notAfter(x509_); - bool long_format; - - if (expire_time->type == V_ASN1_UTCTIME) { - long_format = false; - } else if (expire_time->type == V_ASN1_GENERALIZEDTIME) { - long_format = true; - } else { - return -1; - } - - return ASN1TimeToSec(expire_time->data, expire_time->length, long_format); -} - OpenSSLIdentity::OpenSSLIdentity( std::unique_ptr key_pair, std::unique_ptr certificate) @@ -600,14 +323,14 @@ bool OpenSSLIdentity::ConfigureIdentity(SSL_CTX* ctx) { const OpenSSLCertificate* cert = &certificate(); if (SSL_CTX_use_certificate(ctx, cert->x509()) != 1 || SSL_CTX_use_PrivateKey(ctx, key_pair_->pkey()) != 1) { - LogSSLErrors("Configuring key and certificate"); + openssl::LogSSLErrors("Configuring key and certificate"); return false; } // If a chain is available, use it. for (size_t i = 1; i < cert_chain_->GetSize(); ++i) { cert = static_cast(&cert_chain_->Get(i)); if (SSL_CTX_add1_chain_cert(ctx, cert->x509()) != 1) { - LogSSLErrors("Configuring intermediate certificate"); + openssl::LogSSLErrors("Configuring intermediate certificate"); return false; } } diff --git a/rtc_base/opensslidentity.h b/rtc_base/opensslidentity.h index c1dc49fb58b..34044276ca7 100644 --- a/rtc_base/opensslidentity.h +++ b/rtc_base/opensslidentity.h @@ -19,6 +19,7 @@ #include "rtc_base/checks.h" #include "rtc_base/constructormagic.h" +#include "rtc_base/opensslcertificate.h" #include "rtc_base/sslidentity.h" typedef struct ssl_ctx_st SSL_CTX; @@ -56,52 +57,6 @@ class OpenSSLKeyPair { RTC_DISALLOW_COPY_AND_ASSIGN(OpenSSLKeyPair); }; -// OpenSSLCertificate encapsulates an OpenSSL X509* certificate object, -// which is also reference counted inside the OpenSSL library. -class OpenSSLCertificate : public SSLCertificate { - public: - // Caller retains ownership of the X509 object. - explicit OpenSSLCertificate(X509* x509); - - static OpenSSLCertificate* Generate(OpenSSLKeyPair* key_pair, - const SSLIdentityParams& params); - static OpenSSLCertificate* FromPEMString(const std::string& pem_string); - - ~OpenSSLCertificate() override; - - OpenSSLCertificate* GetReference() const override; - - X509* x509() const { return x509_; } - - std::string ToPEMString() const override; - void ToDER(Buffer* der_buffer) const override; - bool operator==(const OpenSSLCertificate& other) const; - bool operator!=(const OpenSSLCertificate& other) const; - - // Compute the digest of the certificate given algorithm - bool ComputeDigest(const std::string& algorithm, - unsigned char* digest, - size_t size, - size_t* length) const override; - - // Compute the digest of a certificate as an X509 * - static bool ComputeDigest(const X509* x509, - const std::string& algorithm, - unsigned char* digest, - size_t size, - size_t* length); - - bool GetSignatureDigestAlgorithm(std::string* algorithm) const override; - - int64_t CertificateExpirationTime() const override; - - private: - void AddReference() const; - - X509* x509_; - RTC_DISALLOW_COPY_AND_ASSIGN(OpenSSLCertificate); -}; - // Holds a keypair and certificate together, and a method to generate // them consistently. class OpenSSLIdentity : public SSLIdentity { diff --git a/rtc_base/opensslcommon.cc b/rtc_base/opensslutility.cc similarity index 65% rename from rtc_base/opensslcommon.cc rename to rtc_base/opensslutility.cc index 521339271ce..34ebc9ec7ef 100644 --- a/rtc_base/opensslcommon.cc +++ b/rtc_base/opensslutility.cc @@ -8,7 +8,9 @@ * be found in the AUTHORS file in the root of the source tree. */ -#include "rtc_base/opensslcommon.h" +#include "rtc_base/opensslutility.h" + +#include #if defined(WEBRTC_POSIX) #include @@ -24,9 +26,15 @@ #include #include +#include "rtc_base/arraysize.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" +#include "rtc_base/numerics/safe_conversions.h" #include "rtc_base/openssl.h" +#include "rtc_base/opensslcertificate.h" +#ifdef WEBRTC_ENABLE_BUILT_IN_SSL_ROOT_CERTIFICATES +#include "rtc_base/sslroots.h" +#endif // WEBRTC_ENABLE_BUILT_IN_SSL_ROOT_CERTIFICATES namespace rtc { namespace openssl { @@ -91,5 +99,37 @@ bool VerifyPeerCertMatchesHost(SSL* ssl, const std::string& host) { return is_valid_cert_name; } +void LogSSLErrors(const std::string& prefix) { + char error_buf[200]; + unsigned long err; // NOLINT + + while ((err = ERR_get_error()) != 0) { + ERR_error_string_n(err, error_buf, sizeof(error_buf)); + RTC_LOG(LS_ERROR) << prefix << ": " << error_buf << "\n"; + } +} + +#ifdef WEBRTC_ENABLE_BUILT_IN_SSL_ROOT_CERTIFICATES +bool LoadBuiltinSSLRootCertificates(SSL_CTX* ctx) { + int count_of_added_certs = 0; + for (size_t i = 0; i < arraysize(kSSLCertCertificateList); i++) { + const unsigned char* cert_buffer = kSSLCertCertificateList[i]; + size_t cert_buffer_len = kSSLCertCertificateSizeList[i]; + X509* cert = d2i_X509(nullptr, &cert_buffer, + checked_cast(cert_buffer_len)); // NOLINT + if (cert) { + int return_value = X509_STORE_add_cert(SSL_CTX_get_cert_store(ctx), cert); + if (return_value == 0) { + RTC_LOG(LS_WARNING) << "Unable to add certificate."; + } else { + count_of_added_certs++; + } + X509_free(cert); + } + } + return count_of_added_certs > 0; +} +#endif // WEBRTC_ENABLE_BUILT_IN_SSL_ROOT_CERTIFICATES + } // namespace openssl } // namespace rtc diff --git a/rtc_base/opensslcommon.h b/rtc_base/opensslutility.h similarity index 53% rename from rtc_base/opensslcommon.h rename to rtc_base/opensslutility.h index c05165b98a9..3e2d7fcab25 100644 --- a/rtc_base/opensslcommon.h +++ b/rtc_base/opensslutility.h @@ -8,12 +8,12 @@ * be found in the AUTHORS file in the root of the source tree. */ -#ifndef RTC_BASE_OPENSSLCOMMON_H_ -#define RTC_BASE_OPENSSLCOMMON_H_ +#ifndef RTC_BASE_OPENSSLUTILITY_H_ +#define RTC_BASE_OPENSSLUTILITY_H_ +#include #include - -typedef struct ssl_st SSL; +#include "rtc_base/sslcertificate.h" namespace rtc { // The openssl namespace holds static helper methods. All methods related @@ -23,7 +23,19 @@ namespace openssl { // Verifies that the hostname provided matches that in the peer certificate // attached to this SSL state. bool VerifyPeerCertMatchesHost(SSL* ssl, const std::string& host); + +// Logs all the errors in the OpenSSL errror queue from the current thread. A +// prefix can be provided for context. +void LogSSLErrors(const std::string& prefix); + +#ifndef WEBRTC_DISABLE_BUILT_IN_SSL_ROOT_CERTIFICATES +// Attempt to add the certificates from the loader into the SSL_CTX. False is +// returned only if there are no certificates returned from the loader or none +// of them can be added to the TrustStore for the provided context. +bool LoadBuiltinSSLRootCertificates(SSL_CTX* ssl_ctx); +#endif // WEBRTC_DISABLE_BUILT_IN_SSL_ROOT_CERTIFICATES + } // namespace openssl } // namespace rtc -#endif // RTC_BASE_OPENSSLCOMMON_H_ +#endif // RTC_BASE_OPENSSLUTILITY_H_ diff --git a/rtc_base/opensslcommon_unittest.cc b/rtc_base/opensslutility_unittest.cc similarity index 96% rename from rtc_base/opensslcommon_unittest.cc rename to rtc_base/opensslutility_unittest.cc index c057f5ff3fc..2f952ae4b27 100644 --- a/rtc_base/opensslcommon_unittest.cc +++ b/rtc_base/opensslutility_unittest.cc @@ -32,7 +32,7 @@ #include "rtc_base/gunit.h" #include "rtc_base/numerics/safe_conversions.h" #include "rtc_base/openssl.h" -#include "rtc_base/opensslcommon.h" +#include "rtc_base/opensslutility.h" #include "rtc_base/sslroots.h" #include "test/gmock.h" @@ -174,13 +174,14 @@ const unsigned char kFakeSSLCertificateLegacy[] = { // The server is deallocated. This client will have a peer certificate available // and is thus suitable for testing VerifyPeerCertMatchesHost. SSL* CreateSSLWithPeerCertificate(const unsigned char* cert, size_t cert_len) { - X509* x509 = d2i_X509(nullptr, &cert, checked_cast(cert_len)); + X509* x509 = + d2i_X509(nullptr, &cert, checked_cast(cert_len)); // NOLINT RTC_CHECK(x509); const unsigned char* key_ptr = kFakeSSLPrivateKey; - EVP_PKEY* key = - d2i_PrivateKey(EVP_PKEY_EC, nullptr, &key_ptr, - checked_cast(arraysize(kFakeSSLPrivateKey))); + EVP_PKEY* key = d2i_PrivateKey( + EVP_PKEY_EC, nullptr, &key_ptr, + checked_cast(arraysize(kFakeSSLPrivateKey))); // NOLINT RTC_CHECK(key); SSL_CTX* ctx = SSL_CTX_new(SSLv23_method()); @@ -225,7 +226,7 @@ SSL* CreateSSLWithPeerCertificate(const unsigned char* cert, size_t cert_len) { } } // namespace -TEST(OpenSSLCommonTest, VerifyPeerCertMatchesHostFailsOnNoPeerCertificate) { +TEST(OpenSSLUtilityTest, VerifyPeerCertMatchesHostFailsOnNoPeerCertificate) { SSL_CTX* ssl_ctx = SSL_CTX_new(DTLSv1_2_client_method()); SSL* ssl = SSL_new(ssl_ctx); @@ -235,7 +236,7 @@ TEST(OpenSSLCommonTest, VerifyPeerCertMatchesHostFailsOnNoPeerCertificate) { SSL_CTX_free(ssl_ctx); } -TEST(OpenSSLCommonTest, VerifyPeerCertMatchesHost) { +TEST(OpenSSLUtilityTest, VerifyPeerCertMatchesHost) { SSL* ssl = CreateSSLWithPeerCertificate(kFakeSSLCertificate, arraysize(kFakeSSLCertificate)); @@ -256,7 +257,7 @@ TEST(OpenSSLCommonTest, VerifyPeerCertMatchesHost) { SSL_free(ssl); } -TEST(OpenSSLCommonTest, VerifyPeerCertMatchesHostLegacy) { +TEST(OpenSSLUtilityTest, VerifyPeerCertMatchesHostLegacy) { SSL* ssl = CreateSSLWithPeerCertificate(kFakeSSLCertificateLegacy, arraysize(kFakeSSLCertificateLegacy)); diff --git a/rtc_base/ssladapter.cc b/rtc_base/ssladapter.cc index 8c62d3b592c..e091f005200 100644 --- a/rtc_base/ssladapter.cc +++ b/rtc_base/ssladapter.cc @@ -26,8 +26,8 @@ SSLAdapter* SSLAdapter::Create(AsyncSocket* socket) { /////////////////////////////////////////////////////////////////////////////// -bool InitializeSSL(VerificationCallback callback) { - return OpenSSLAdapter::InitializeSSL(callback); +bool InitializeSSL() { + return OpenSSLAdapter::InitializeSSL(); } bool CleanupSSL() { diff --git a/rtc_base/ssladapter.h b/rtc_base/ssladapter.h index 6b841543d5c..4843d264fcc 100644 --- a/rtc_base/ssladapter.h +++ b/rtc_base/ssladapter.h @@ -11,7 +11,11 @@ #ifndef RTC_BASE_SSLADAPTER_H_ #define RTC_BASE_SSLADAPTER_H_ +#include +#include + #include "rtc_base/asyncsocket.h" +#include "rtc_base/sslcertificate.h" #include "rtc_base/sslstreamadapter.h" namespace rtc { @@ -26,8 +30,13 @@ class SSLAdapter; class SSLAdapterFactory { public: virtual ~SSLAdapterFactory() {} + // Specifies whether TLS or DTLS is to be used for the SSL adapters. virtual void SetMode(SSLMode mode) = 0; + + // Specify a custom certificate verifier for SSL. + virtual void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) = 0; + // Creates a new SSL adapter, but from a shared context. virtual SSLAdapter* CreateAdapter(AsyncSocket* socket) = 0; @@ -54,6 +63,8 @@ class SSLAdapter : public AsyncSocketAdapter { // Do DTLS or TLS (default is TLS, if unspecified) virtual void SetMode(SSLMode mode) = 0; + // Specify a custom certificate verifier for SSL. + virtual void SetCertVerifier(SSLCertificateVerifier* ssl_cert_verifier) = 0; // Set the certificate this socket will present to incoming clients. virtual void SetIdentity(SSLIdentity* identity) = 0; @@ -82,11 +93,9 @@ class SSLAdapter : public AsyncSocketAdapter { /////////////////////////////////////////////////////////////////////////////// -typedef bool (*VerificationCallback)(void* cert); - // Call this on the main thread, before using SSL. // Call CleanupSSL when finished with SSL. -bool InitializeSSL(VerificationCallback callback = nullptr); +bool InitializeSSL(); // Call to cleanup additional threads, and also the main thread. bool CleanupSSL(); diff --git a/rtc_base/sslcertificate.cc b/rtc_base/sslcertificate.cc new file mode 100644 index 00000000000..fcbd10256f4 --- /dev/null +++ b/rtc_base/sslcertificate.cc @@ -0,0 +1,142 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "rtc_base/sslcertificate.h" + +#include +#include +#include + +#include "rtc_base/base64.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "rtc_base/opensslcertificate.h" +#include "rtc_base/ptr_util.h" +#include "rtc_base/sslfingerprint.h" + +namespace rtc { + +////////////////////////////////////////////////////////////////////// +// SSLCertificateStats +////////////////////////////////////////////////////////////////////// + +SSLCertificateStats::SSLCertificateStats( + std::string&& fingerprint, + std::string&& fingerprint_algorithm, + std::string&& base64_certificate, + std::unique_ptr&& issuer) + : fingerprint(std::move(fingerprint)), + fingerprint_algorithm(std::move(fingerprint_algorithm)), + base64_certificate(std::move(base64_certificate)), + issuer(std::move(issuer)) {} + +SSLCertificateStats::~SSLCertificateStats() {} + +////////////////////////////////////////////////////////////////////// +// SSLCertificate +////////////////////////////////////////////////////////////////////// + +std::unique_ptr SSLCertificate::GetStats() const { + // TODO(bemasc): Move this computation to a helper class that caches these + // values to reduce CPU use in |StatsCollector::GetStats|. This will require + // adding a fast |SSLCertificate::Equals| to detect certificate changes. + std::string digest_algorithm; + if (!GetSignatureDigestAlgorithm(&digest_algorithm)) + return nullptr; + + // |SSLFingerprint::Create| can fail if the algorithm returned by + // |SSLCertificate::GetSignatureDigestAlgorithm| is not supported by the + // implementation of |SSLCertificate::ComputeDigest|. This currently happens + // with MD5- and SHA-224-signed certificates when linked to libNSS. + std::unique_ptr ssl_fingerprint( + SSLFingerprint::Create(digest_algorithm, this)); + if (!ssl_fingerprint) + return nullptr; + std::string fingerprint = ssl_fingerprint->GetRfc4572Fingerprint(); + + Buffer der_buffer; + ToDER(&der_buffer); + std::string der_base64; + Base64::EncodeFromArray(der_buffer.data(), der_buffer.size(), &der_base64); + + return rtc::MakeUnique(std::move(fingerprint), + std::move(digest_algorithm), + std::move(der_base64), nullptr); +} + +std::unique_ptr SSLCertificate::GetUniqueReference() const { + return WrapUnique(GetReference()); +} + +////////////////////////////////////////////////////////////////////// +// SSLCertChain +////////////////////////////////////////////////////////////////////// + +SSLCertChain::SSLCertChain(std::vector> certs) + : certs_(std::move(certs)) {} + +SSLCertChain::SSLCertChain(const std::vector& certs) { + RTC_DCHECK(!certs.empty()); + certs_.resize(certs.size()); + std::transform( + certs.begin(), certs.end(), certs_.begin(), + [](const SSLCertificate* cert) -> std::unique_ptr { + return cert->GetUniqueReference(); + }); +} + +SSLCertChain::SSLCertChain(const SSLCertificate* cert) { + certs_.push_back(cert->GetUniqueReference()); +} + +SSLCertChain::SSLCertChain(SSLCertChain&& rhs) = default; + +SSLCertChain& SSLCertChain::operator=(SSLCertChain&&) = default; + +SSLCertChain::~SSLCertChain() {} + +SSLCertChain* SSLCertChain::Copy() const { + std::vector> new_certs(certs_.size()); + std::transform(certs_.begin(), certs_.end(), new_certs.begin(), + [](const std::unique_ptr& cert) + -> std::unique_ptr { + return cert->GetUniqueReference(); + }); + return new SSLCertChain(std::move(new_certs)); +} + +std::unique_ptr SSLCertChain::UniqueCopy() const { + return WrapUnique(Copy()); +} + +std::unique_ptr SSLCertChain::GetStats() const { + // We have a linked list of certificates, starting with the first element of + // |certs_| and ending with the last element of |certs_|. The "issuer" of a + // certificate is the next certificate in the chain. Stats are produced for + // each certificate in the list. Here, the "issuer" is the issuer's stats. + std::unique_ptr issuer; + // The loop runs in reverse so that the |issuer| is known before the + // certificate issued by |issuer|. + for (ptrdiff_t i = certs_.size() - 1; i >= 0; --i) { + std::unique_ptr new_stats = certs_[i]->GetStats(); + if (new_stats) { + new_stats->issuer = std::move(issuer); + } + issuer = std::move(new_stats); + } + return issuer; +} + +// static +SSLCertificate* SSLCertificate::FromPEMString(const std::string& pem_string) { + return OpenSSLCertificate::FromPEMString(pem_string); +} + +} // namespace rtc diff --git a/rtc_base/sslcertificate.h b/rtc_base/sslcertificate.h new file mode 100644 index 00000000000..29c4db58ef5 --- /dev/null +++ b/rtc_base/sslcertificate.h @@ -0,0 +1,147 @@ +/* + * Copyright 2018 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// Generic interface for SSL Certificates, used in both the SSLAdapter +// for TLS TURN connections and the SSLStreamAdapter for DTLS Peer to Peer +// Connections for SRTP Key negotiation and SCTP encryption. + +#ifndef RTC_BASE_SSLCERTIFICATE_H_ +#define RTC_BASE_SSLCERTIFICATE_H_ + +#include +#include +#include +#include + +#include "rtc_base/buffer.h" +#include "rtc_base/constructormagic.h" +#include "rtc_base/messagedigest.h" +#include "rtc_base/timeutils.h" + +namespace rtc { + +struct SSLCertificateStats { + SSLCertificateStats(std::string&& fingerprint, + std::string&& fingerprint_algorithm, + std::string&& base64_certificate, + std::unique_ptr&& issuer); + ~SSLCertificateStats(); + std::string fingerprint; + std::string fingerprint_algorithm; + std::string base64_certificate; + std::unique_ptr issuer; +}; + +// Abstract interface overridden by SSL library specific +// implementations. + +// A somewhat opaque type used to encapsulate a certificate. +// Wraps the SSL library's notion of a certificate, with reference counting. +// The SSLCertificate object is pretty much immutable once created. +// (The OpenSSL implementation only does reference counting and +// possibly caching of intermediate results.) +class SSLCertificate { + public: + // Parses and builds a certificate from a PEM encoded string. + // Returns null on failure. + // The length of the string representation of the certificate is + // stored in *pem_length if it is non-null, and only if + // parsing was successful. + // Caller is responsible for freeing the returned object. + static SSLCertificate* FromPEMString(const std::string& pem_string); + virtual ~SSLCertificate() {} + + // Returns a new SSLCertificate object instance wrapping the same + // underlying certificate, including its chain if present. Caller is + // responsible for freeing the returned object. Use GetUniqueReference + // instead. + virtual SSLCertificate* GetReference() const = 0; + + std::unique_ptr GetUniqueReference() const; + + // Returns a PEM encoded string representation of the certificate. + virtual std::string ToPEMString() const = 0; + + // Provides a DER encoded binary representation of the certificate. + virtual void ToDER(Buffer* der_buffer) const = 0; + + // Gets the name of the digest algorithm that was used to compute this + // certificate's signature. + virtual bool GetSignatureDigestAlgorithm(std::string* algorithm) const = 0; + + // Compute the digest of the certificate given algorithm + virtual bool ComputeDigest(const std::string& algorithm, + unsigned char* digest, + size_t size, + size_t* length) const = 0; + + // Returns the time in seconds relative to epoch, 1970-01-01T00:00:00Z (UTC), + // or -1 if an expiration time could not be retrieved. + virtual int64_t CertificateExpirationTime() const = 0; + + // Gets information (fingerprint, etc.) about this certificate. This is used + // for certificate stats, see + // https://w3c.github.io/webrtc-stats/#certificatestats-dict*. + std::unique_ptr GetStats() const; +}; + +// SSLCertChain is a simple wrapper for a vector of SSLCertificates. It serves +// primarily to ensure proper memory management (especially deletion) of the +// SSLCertificate pointers. +class SSLCertChain { + public: + explicit SSLCertChain(std::vector> certs); + // These constructors copy the provided SSLCertificate(s), so the caller + // retains ownership. + explicit SSLCertChain(const std::vector& certs); + explicit SSLCertChain(const SSLCertificate* cert); + // Allow move semantics for the object. + SSLCertChain(SSLCertChain&&); + SSLCertChain& operator=(SSLCertChain&&); + + ~SSLCertChain(); + + // Vector access methods. + size_t GetSize() const { return certs_.size(); } + + // Returns a temporary reference, only valid until the chain is destroyed. + const SSLCertificate& Get(size_t pos) const { return *(certs_[pos]); } + + // Returns a new SSLCertChain object instance wrapping the same underlying + // certificate chain. Caller is responsible for freeing the returned object. + SSLCertChain* Copy() const; + // Same as above, but returning a unique_ptr for convenience. + std::unique_ptr UniqueCopy() const; + + // Gets information (fingerprint, etc.) about this certificate chain. This is + // used for certificate stats, see + // https://w3c.github.io/webrtc-stats/#certificatestats-dict*. + std::unique_ptr GetStats() const; + + private: + std::vector> certs_; + + RTC_DISALLOW_COPY_AND_ASSIGN(SSLCertChain); +}; + +// SSLCertificateVerifier provides a simple interface to allow third parties to +// define their own certificate verification code. It is completely independent +// from the underlying SSL implementation. +class SSLCertificateVerifier { + public: + virtual ~SSLCertificateVerifier() = default; + // Returns true if the certificate is valid, else false. It is up to the + // implementer to define what a valid certificate looks like. + virtual bool Verify(const SSLCertificate& certificate) = 0; +}; + +} // namespace rtc + +#endif // RTC_BASE_SSLCERTIFICATE_H_ diff --git a/rtc_base/sslidentity.cc b/rtc_base/sslidentity.cc index 1514e52be1f..0fab22363df 100644 --- a/rtc_base/sslidentity.cc +++ b/rtc_base/sslidentity.cc @@ -24,56 +24,14 @@ namespace rtc { +////////////////////////////////////////////////////////////////////// +// KeyParams +////////////////////////////////////////////////////////////////////// + const char kPemTypeCertificate[] = "CERTIFICATE"; const char kPemTypeRsaPrivateKey[] = "RSA PRIVATE KEY"; const char kPemTypeEcPrivateKey[] = "EC PRIVATE KEY"; -SSLCertificateStats::SSLCertificateStats( - std::string&& fingerprint, - std::string&& fingerprint_algorithm, - std::string&& base64_certificate, - std::unique_ptr&& issuer) - : fingerprint(std::move(fingerprint)), - fingerprint_algorithm(std::move(fingerprint_algorithm)), - base64_certificate(std::move(base64_certificate)), - issuer(std::move(issuer)) { -} - -SSLCertificateStats::~SSLCertificateStats() { -} - -std::unique_ptr SSLCertificate::GetStats() const { - // TODO(bemasc): Move this computation to a helper class that caches these - // values to reduce CPU use in |StatsCollector::GetStats|. This will require - // adding a fast |SSLCertificate::Equals| to detect certificate changes. - std::string digest_algorithm; - if (!GetSignatureDigestAlgorithm(&digest_algorithm)) - return nullptr; - - // |SSLFingerprint::Create| can fail if the algorithm returned by - // |SSLCertificate::GetSignatureDigestAlgorithm| is not supported by the - // implementation of |SSLCertificate::ComputeDigest|. This currently happens - // with MD5- and SHA-224-signed certificates when linked to libNSS. - std::unique_ptr ssl_fingerprint( - SSLFingerprint::Create(digest_algorithm, this)); - if (!ssl_fingerprint) - return nullptr; - std::string fingerprint = ssl_fingerprint->GetRfc4572Fingerprint(); - - Buffer der_buffer; - ToDER(&der_buffer); - std::string der_base64; - Base64::EncodeFromArray(der_buffer.data(), der_buffer.size(), &der_base64); - - return rtc::MakeUnique(std::move(fingerprint), - std::move(digest_algorithm), - std::move(der_base64), nullptr); -} - -std::unique_ptr SSLCertificate::GetUniqueReference() const { - return WrapUnique(GetReference()); -} - KeyParams::KeyParams(KeyType key_type) { if (key_type == KT_ECDSA) { type_ = KT_ECDSA; @@ -127,6 +85,10 @@ KeyType IntKeyTypeFamilyToKeyType(int key_type_family) { return static_cast(key_type_family); } +////////////////////////////////////////////////////////////////////// +// SSLIdentity +////////////////////////////////////////////////////////////////////// + bool SSLIdentity::PemToDer(const std::string& pem_type, const std::string& pem_string, std::string* der) { @@ -177,62 +139,6 @@ std::string SSLIdentity::DerToPem(const std::string& pem_type, return result.str(); } -SSLCertChain::SSLCertChain(std::vector> certs) - : certs_(std::move(certs)) {} - -SSLCertChain::SSLCertChain(const std::vector& certs) { - RTC_DCHECK(!certs.empty()); - certs_.resize(certs.size()); - std::transform( - certs.begin(), certs.end(), certs_.begin(), - [](const SSLCertificate* cert) -> std::unique_ptr { - return cert->GetUniqueReference(); - }); -} - -SSLCertChain::SSLCertChain(const SSLCertificate* cert) { - certs_.push_back(cert->GetUniqueReference()); -} - -SSLCertChain::~SSLCertChain() {} - -SSLCertChain* SSLCertChain::Copy() const { - std::vector> new_certs(certs_.size()); - std::transform(certs_.begin(), certs_.end(), new_certs.begin(), - [](const std::unique_ptr& cert) - -> std::unique_ptr { - return cert->GetUniqueReference(); - }); - return new SSLCertChain(std::move(new_certs)); -} - -std::unique_ptr SSLCertChain::UniqueCopy() const { - return WrapUnique(Copy()); -} - -std::unique_ptr SSLCertChain::GetStats() const { - // We have a linked list of certificates, starting with the first element of - // |certs_| and ending with the last element of |certs_|. The "issuer" of a - // certificate is the next certificate in the chain. Stats are produced for - // each certificate in the list. Here, the "issuer" is the issuer's stats. - std::unique_ptr issuer; - // The loop runs in reverse so that the |issuer| is known before the - // certificate issued by |issuer|. - for (ptrdiff_t i = certs_.size() - 1; i >= 0; --i) { - std::unique_ptr new_stats = certs_[i]->GetStats(); - if (new_stats) { - new_stats->issuer = std::move(issuer); - } - issuer = std::move(new_stats); - } - return issuer; -} - -// static -SSLCertificate* SSLCertificate::FromPEMString(const std::string& pem_string) { - return OpenSSLCertificate::FromPEMString(pem_string); -} - // static SSLIdentity* SSLIdentity::GenerateWithExpiration(const std::string& common_name, const KeyParams& key_params, @@ -280,6 +186,10 @@ bool operator!=(const SSLIdentity& a, const SSLIdentity& b) { return !(a == b); } +////////////////////////////////////////////////////////////////////// +// Helper Functions +////////////////////////////////////////////////////////////////////// + // Read |n| bytes from ASN1 number string at *|pp| and return the numeric value. // Update *|pp| and *|np| to reflect number of read bytes. static inline int ASN1ReadInt(const unsigned char** pp, size_t* np, size_t n) { diff --git a/rtc_base/sslidentity.h b/rtc_base/sslidentity.h index d14610b0f9a..1379d733be5 100644 --- a/rtc_base/sslidentity.h +++ b/rtc_base/sslidentity.h @@ -21,113 +21,11 @@ #include "rtc_base/buffer.h" #include "rtc_base/constructormagic.h" #include "rtc_base/messagedigest.h" +#include "rtc_base/sslcertificate.h" #include "rtc_base/timeutils.h" namespace rtc { -// Forward declaration due to circular dependency with SSLCertificate. -class SSLCertChain; - -struct SSLCertificateStats { - SSLCertificateStats(std::string&& fingerprint, - std::string&& fingerprint_algorithm, - std::string&& base64_certificate, - std::unique_ptr&& issuer); - ~SSLCertificateStats(); - std::string fingerprint; - std::string fingerprint_algorithm; - std::string base64_certificate; - std::unique_ptr issuer; -}; - -// Abstract interface overridden by SSL library specific -// implementations. - -// A somewhat opaque type used to encapsulate a certificate. -// Wraps the SSL library's notion of a certificate, with reference counting. -// The SSLCertificate object is pretty much immutable once created. -// (The OpenSSL implementation only does reference counting and -// possibly caching of intermediate results.) -class SSLCertificate { - public: - // Parses and builds a certificate from a PEM encoded string. - // Returns null on failure. - // The length of the string representation of the certificate is - // stored in *pem_length if it is non-null, and only if - // parsing was successful. - // Caller is responsible for freeing the returned object. - static SSLCertificate* FromPEMString(const std::string& pem_string); - virtual ~SSLCertificate() {} - - // Returns a new SSLCertificate object instance wrapping the same - // underlying certificate, including its chain if present. Caller is - // responsible for freeing the returned object. Use GetUniqueReference - // instead. - virtual SSLCertificate* GetReference() const = 0; - - std::unique_ptr GetUniqueReference() const; - - // Returns a PEM encoded string representation of the certificate. - virtual std::string ToPEMString() const = 0; - - // Provides a DER encoded binary representation of the certificate. - virtual void ToDER(Buffer* der_buffer) const = 0; - - // Gets the name of the digest algorithm that was used to compute this - // certificate's signature. - virtual bool GetSignatureDigestAlgorithm(std::string* algorithm) const = 0; - - // Compute the digest of the certificate given algorithm - virtual bool ComputeDigest(const std::string& algorithm, - unsigned char* digest, - size_t size, - size_t* length) const = 0; - - // Returns the time in seconds relative to epoch, 1970-01-01T00:00:00Z (UTC), - // or -1 if an expiration time could not be retrieved. - virtual int64_t CertificateExpirationTime() const = 0; - - // Gets information (fingerprint, etc.) about this certificate. This is used - // for certificate stats, see - // https://w3c.github.io/webrtc-stats/#certificatestats-dict*. - std::unique_ptr GetStats() const; -}; - -// SSLCertChain is a simple wrapper for a vector of SSLCertificates. It serves -// primarily to ensure proper memory management (especially deletion) of the -// SSLCertificate pointers. -class SSLCertChain { - public: - explicit SSLCertChain(std::vector> certs); - // These constructors copy the provided SSLCertificate(s), so the caller - // retains ownership. - explicit SSLCertChain(const std::vector& certs); - explicit SSLCertChain(const SSLCertificate* cert); - ~SSLCertChain(); - - // Vector access methods. - size_t GetSize() const { return certs_.size(); } - - // Returns a temporary reference, only valid until the chain is destroyed. - const SSLCertificate& Get(size_t pos) const { return *(certs_[pos]); } - - // Returns a new SSLCertChain object instance wrapping the same underlying - // certificate chain. Caller is responsible for freeing the returned object. - SSLCertChain* Copy() const; - // Same as above, but returning a unique_ptr for convenience. - std::unique_ptr UniqueCopy() const; - - // Gets information (fingerprint, etc.) about this certificate chain. This is - // used for certificate stats, see - // https://w3c.github.io/webrtc-stats/#certificatestats-dict*. - std::unique_ptr GetStats() const; - - private: - std::vector> certs_; - - RTC_DISALLOW_COPY_AND_ASSIGN(SSLCertChain); -}; - // KT_LAST is intended for vector declarations and loops over all key types; // it does not represent any key type in itself. // KT_DEFAULT is used as the default KeyType for KeyParams. diff --git a/rtc_base/testcertificateverifier.h b/rtc_base/testcertificateverifier.h new file mode 100644 index 00000000000..8ad6e4d736c --- /dev/null +++ b/rtc_base/testcertificateverifier.h @@ -0,0 +1,34 @@ +/* + * Copyright 2018 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef RTC_BASE_TESTCERTIFICATEVERIFIER_H_ +#define RTC_BASE_TESTCERTIFICATEVERIFIER_H_ + +#include "rtc_base/sslcertificate.h" + +namespace rtc { + +class TestCertificateVerifier : public SSLCertificateVerifier { + public: + TestCertificateVerifier() = default; + ~TestCertificateVerifier() override = default; + + bool Verify(const SSLCertificate& certificate) override { + call_count_++; + return verify_certificate_; + } + + size_t call_count_ = 0; + bool verify_certificate_ = true; +}; + +} // namespace rtc + +#endif // RTC_BASE_TESTCERTIFICATEVERIFIER_H_ diff --git a/webrtc.gni b/webrtc.gni index 35a65e76f3f..95c24d4968d 100644 --- a/webrtc.gni +++ b/webrtc.gni @@ -31,6 +31,11 @@ if (is_ios) { } declare_args() { + # Setting this to false will require the API user to pass in their own + # SSLCertificateVerifier to verify the certificates presented from a + # TLS-TURN server. In return disabling this saves around 100kb in the binary. + rtc_builtin_ssl_root_certificates = true + # Include the iLBC audio codec? rtc_include_ilbc = true