Skip to content

Commit

Permalink
GEODE-9991: Refactor for consistency and add tests. (apache#7533)
Browse files Browse the repository at this point in the history
* Combine common configuration into method for consistency.
* Adds tests for new extracted methods.
  • Loading branch information
jake-at-work authored Apr 6, 2022
1 parent 30bd1ce commit 75ea5f7
Show file tree
Hide file tree
Showing 4 changed files with 354 additions and 140 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -192,21 +192,54 @@ public SSLParameterExtension getSSLParameterExtension() {
}

/**
* Returns true if ciphers is either null, empty or is set to "any" (ignoring case)
* Checks if "any" cipher is specified in {@link #getCiphers()}
*
* @return {@code true} if ciphers is either {@code null}, empty or is set to "any"
* (ignoring case), otherwise {@code false}.
*/
public boolean isAnyCiphers() {
return isAnyCiphers(ciphers);
}

/**
* Checks if "any" cipher is specified in {@code ciphers}.
*
* @param ciphers Comma or space separated list of cipher names.
* @return {@code true} if {@code ciphers} is either {@code null}, empty or is set to "any"
* (ignoring case), otherwise {@code false}.
*/
public static boolean isAnyCiphers(final String ciphers) {
return StringUtils.isBlank(ciphers) || "any".equalsIgnoreCase(ciphers);
}

/**
* Returns true if protocols is either null, empty or is set to "any" (ignoring case)
* Checks if "any" cipher is specified in {@code ciphers}.
*
* @param ciphers Array of cipher names.
* @return {@code true} if {@code ciphers} is either {@code null}, empty or first entry is "any"
* (ignoring case), otherwise {@code false}.
*/
public static boolean isAnyCiphers(final String... ciphers) {
return ArrayUtils.isEmpty(ciphers) || "any".equalsIgnoreCase(ciphers[0]);
}

/**
* Checks if "any" protocol is specified in {@code protocols}.
*
* @param protocols Comma or space separated list of protocol names.
* @return {@code true} if {@code protocols} is either {@code null}, empty or is set to "any"
* (ignoring case), otherwise {@code false}.
*/
public static boolean isAnyProtocols(final String protocols) {
return StringUtils.isBlank(protocols) || "any".equalsIgnoreCase(protocols);
}

/**
* Returns true if protocols is either null, empty or is set to "any" (ignoring case)
* Checks if "any" protocol is specified in {@code protocols}.
*
* @param protocols Array of protocol names.
* @return {@code true} if {@code protocols} is either {@code null}, empty or first entry is "any"
* (ignoring case), otherwise {@code false}.
*/
public static boolean isAnyProtocols(final String... protocols) {
return ArrayUtils.isEmpty(protocols) || "any".equalsIgnoreCase(protocols[0]);
Expand Down Expand Up @@ -394,7 +427,6 @@ public Builder setSSLParameterExtension(
SSLParameterExtension sslParameterExtension =
CallbackInstantiator.getObjectOfTypeFromClassName(sslParameterExtensionConfig,
SSLParameterExtension.class);
ids.getConfig().getDistributedSystemId();

sslParameterExtension.init(
new SSLParameterExtensionContextImpl(ids.getConfig().getDistributedSystemId()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*/
package org.apache.geode.internal.net;

package org.apache.geode.internal.net;

import static org.apache.commons.lang3.ObjectUtils.getIfNull;
import static org.apache.geode.internal.net.filewatch.FileWatchingX509ExtendedKeyManager.newFileWatchingKeyManager;
import static org.apache.geode.internal.net.filewatch.FileWatchingX509ExtendedTrustManager.newFileWatchingTrustManager;

Expand Down Expand Up @@ -76,7 +77,6 @@
import org.apache.geode.net.SSLParameterExtension;
import org.apache.geode.util.internal.GeodeGlossary;


/**
* SocketCreators are built using a SocketCreatorFactory using Geode distributed-system properties.
* They know how to properly configure sockets for TLS (SSL) communications and perform
Expand Down Expand Up @@ -115,17 +115,16 @@ public class SocketCreator extends TcpSocketCreatorImpl {
* Only print this SocketCreator's config once
*/
private boolean configShown = false;

/**
* Only print hostname validation disabled log once
*/
private boolean hostnameValidationDisabledLogShown = false;


private SSLContext sslContext;

private final SSLConfig sslConfig;


private ClientSocketFactory clientSocketFactory;

/**
Expand All @@ -136,10 +135,6 @@ public class SocketCreator extends TcpSocketCreatorImpl {
public static final boolean ENABLE_TCP_KEEP_ALIVE =
AdvancedSocketCreatorImpl.ENABLE_TCP_KEEP_ALIVE;

// -------------------------------------------------------------------------
// Static instance accessors
// -------------------------------------------------------------------------

/**
* This method has migrated to LocalHostUtil but is kept in place here for
* backward-compatibility testing.
Expand All @@ -152,16 +147,15 @@ public static InetAddress getLocalHost() throws UnknownHostException {
return LocalHostUtil.getLocalHost();
}


/**
* returns the host name for the given inet address, using a local cache of names to avoid dns
* hits and duplicate strings
*/
public static String getHostName(InetAddress addr) {
String result = hostNames.get(addr);
public static String getHostName(InetAddress address) {
String result = hostNames.get(address);
if (result == null) {
result = addr.getHostName();
hostNames.put(addr, result);
result = address.getHostName();
hostNames.put(address, result);
}
return result;
}
Expand All @@ -173,11 +167,6 @@ public static void resetHostNameCache() {
hostNames.clear();
}


// -------------------------------------------------------------------------
// Constructor
// -------------------------------------------------------------------------

/**
* Constructs new SocketCreator instance.
*/
Expand All @@ -194,15 +183,11 @@ public SocketCreator(final SSLConfig sslConfig) {

/** returns the hostname or address for this client */
public static String getClientHostName() throws UnknownHostException {
InetAddress hostAddr = LocalHostUtil.getLocalHost();
return SocketCreator.use_client_host_name ? hostAddr.getCanonicalHostName()
: hostAddr.getHostAddress();
InetAddress address = LocalHostUtil.getLocalHost();
return SocketCreator.use_client_host_name ? address.getCanonicalHostName()
: address.getHostAddress();
}

// -------------------------------------------------------------------------
// Initializers (change SocketCreator state)
// -------------------------------------------------------------------------

protected void initializeCreators() {
clusterSocketCreator = new SCClusterSocketCreator(this);
clientSocketCreator = new SCClientSocketCreator(this);
Expand Down Expand Up @@ -348,10 +333,6 @@ protected boolean useSSL() {
return sslConfig.isEnabled();
}

// -------------------------------------------------------------------------
// Public methods
// -------------------------------------------------------------------------

/**
* Returns an SSLEngine that can be used to perform TLS handshakes and communication
*/
Expand All @@ -375,7 +356,7 @@ void configureSSLParameters(final SSLParameters parameters, final String hostNam
final int port, final boolean clientSocket) {
if (sslConfig.doEndpointIdentification()) {
// set server-names so that endpoint identification algorithms can find what's expected
setServerNames(parameters, new HostAndPort(hostName, port));
addServerNameIfNotSet(parameters, new HostAndPort(hostName, port));
}

if (clientSocket) {
Expand All @@ -384,16 +365,9 @@ void configureSSLParameters(final SSLParameters parameters, final String hostNam
parameters.setNeedClientAuth(sslConfig.isRequireAuth());
}

final String[] protocols = clientSocket ? sslConfig.getClientProtocolsAsStringArray()
: sslConfig.getServerProtocolsAsStringArray();
if (!SSLConfig.isAnyProtocols(protocols)) {
parameters.setProtocols(protocols);
}
configureProtocols(clientSocket, parameters);

final String[] ciphers = sslConfig.getCiphersAsStringArray();
if (ciphers != null && !"any".equalsIgnoreCase(ciphers[0])) {
parameters.setCipherSuites(ciphers);
}
configureCipherSuites(parameters);
}

/**
Expand All @@ -404,7 +378,7 @@ void configureSSLParameters(final SSLParameters parameters, final String hostNam
* @param socketChannel the socket's NIO channel
* @param engine the sslEngine (see createSSLEngine)
* @param timeout handshake timeout in milliseconds. No timeout if <= 0
* @param peerNetBuffer the buffer to use in reading data fron socketChannel. This should also be
* @param peerNetBuffer the buffer to use in reading data from socketChannel. This should also be
* used in subsequent I/O operations
* @return The SSLEngine to be used in processing data for sending/receiving from the channel
*/
Expand Down Expand Up @@ -452,21 +426,17 @@ public NioSslEngine handshakeSSLSocketChannel(SocketChannel socketChannel,
return nioSslEngine;
}

/**
* @return true if the parameters have been modified by this method
*/
private boolean checkAndEnableHostnameValidation(SSLParameters sslParameters) {
void checkAndEnableHostnameValidation(final SSLParameters sslParameters) {
if (sslConfig.doEndpointIdentification()) {
sslParameters.setEndpointIdentificationAlgorithm("HTTPS");
return true;
return;
}
if (!hostnameValidationDisabledLogShown) {
logger.info("Your SSL configuration disables hostname validation. "
+ "ssl-endpoint-identification-enabled should be set to true when SSL is enabled. "
+ "Please refer to the Apache GEODE SSL Documentation for SSL Property: ssl‑endpoint‑identification‑enabled");
hostnameValidationDisabledLogShown = true;
}
return false;
}

/**
Expand Down Expand Up @@ -514,11 +484,11 @@ void handshakeIfSocketIsSSL(Socket socket, int timeout) throws IOException {
* client/server/advanced interfaces because it references WAN classes that aren't
* available to them.
*/
public ServerSocket createServerSocket(int nport, int backlog, InetAddress bindAddr,
public ServerSocket createServerSocket(int port, int backlog, InetAddress bindAddress,
List<GatewayTransportFilter> transportFilters, int socketBufferSize) throws IOException {
if (transportFilters.isEmpty()) {
return ((SCClusterSocketCreator) forCluster())
.createServerSocket(nport, backlog, bindAddr, socketBufferSize, useSSL());
.createServerSocket(port, backlog, bindAddress, socketBufferSize, useSSL());
} else {
printConfig();
ServerSocket result = new TransportFilterServerSocket(transportFilters);
Expand All @@ -528,61 +498,35 @@ public ServerSocket createServerSocket(int nport, int backlog, InetAddress bindA
// java.net.ServerSocket.setReceiverBufferSize javadocs)
result.setReceiveBufferSize(socketBufferSize);
try {
result.bind(new InetSocketAddress(bindAddr, nport), backlog);
result.bind(new InetSocketAddress(bindAddress, port), backlog);
} catch (BindException e) {
BindException throwMe = new BindException(
String.format("Failed to create server socket on %s[%s]", bindAddr, nport));
String.format("Failed to create server socket on %s[%s]", bindAddress, port));
throwMe.initCause(e);
throw throwMe;
}
return result;
}
}


// -------------------------------------------------------------------------
// Private implementation methods
// -------------------------------------------------------------------------


/**
* When a socket is connected to a server socket, it should be passed to this method for SSL
* configuration.
*/
void configureClientSSLSocket(Socket socket, HostAndPort addr, int timeout) throws IOException {
void configureClientSSLSocket(final Socket socket, final HostAndPort address, int timeout)
throws IOException {
if (socket instanceof SSLSocket) {
SSLSocket sslSocket = (SSLSocket) socket;
final SSLSocket sslSocket = (SSLSocket) socket;

sslSocket.setUseClientMode(true);
sslSocket.setEnableSessionCreation(true);

SSLParameters parameters = sslSocket.getSSLParameters();
boolean updateSSLParameters =
checkAndEnableHostnameValidation(parameters);

if (setServerNames(parameters, addr)) {
updateSSLParameters = true;
}

SSLParameterExtension sslParameterExtension = sslConfig.getSSLParameterExtension();
if (sslParameterExtension != null) {
parameters =
sslParameterExtension.modifySSLClientSocketParameters(parameters);
updateSSLParameters = true;
}

if (updateSSLParameters) {
sslSocket.setSSLParameters(parameters);
}

String[] protocols = sslConfig.getClientProtocolsAsStringArray();
if (!SSLConfig.isAnyProtocols(protocols)) {
sslSocket.setEnabledProtocols(protocols);
}
String[] ciphers = sslConfig.getCiphersAsStringArray();
if (ciphers != null && !"any".equalsIgnoreCase(ciphers[0])) {
sslSocket.setEnabledCipherSuites(ciphers);
}
final SSLParameters parameters = sslSocket.getSSLParameters();
checkAndEnableHostnameValidation(parameters);
addServerNameIfNotSet(parameters, address);
configureProtocols(true, parameters);
configureCipherSuites(parameters);
sslSocket.setSSLParameters(applySSLParameterExtensions(parameters));

try {
if (timeout > 0) {
Expand Down Expand Up @@ -615,25 +559,43 @@ void configureClientSSLSocket(Socket socket, HostAndPort addr, int timeout) thro
}
}

/**
* returns true if the SSLParameters are altered, false if not
*/
private boolean setServerNames(SSLParameters modifiedParams, HostAndPort addr) {
List<SNIServerName> oldNames = modifiedParams.getServerNames();
oldNames = oldNames == null ? Collections.emptyList() : oldNames;
final List<SNIServerName> serverNames = new ArrayList<>(oldNames);
SSLParameters applySSLParameterExtensions(final SSLParameters parameters) {
final SSLParameterExtension sslParameterExtension = sslConfig.getSSLParameterExtension();
if (sslParameterExtension != null) {
return sslParameterExtension.modifySSLClientSocketParameters(parameters);
}
return parameters;
}

void configureProtocols(final boolean clientSocket, final SSLParameters parameters) {
final String[] protocols = clientSocket ? sslConfig.getClientProtocolsAsStringArray()
: sslConfig.getServerProtocolsAsStringArray();
if (!SSLConfig.isAnyProtocols(protocols)) {
parameters.setProtocols(protocols);
}
}

void configureCipherSuites(final SSLParameters parameters) {
final String[] ciphers = sslConfig.getCiphersAsStringArray();
if (!SSLConfig.isAnyCiphers(ciphers)) {
parameters.setCipherSuites(ciphers);
}
}

static void addServerNameIfNotSet(final SSLParameters parameters,
final HostAndPort address) {
final List<SNIServerName> serverNames =
new ArrayList<>(getIfNull(parameters.getServerNames(), Collections::emptyList));

if (serverNames.stream()
.mapToInt(SNIServerName::getType)
.anyMatch(type -> type == StandardConstants.SNI_HOST_NAME)) {
// we already have a SNI hostname set. Do nothing.
return false;
return;
}

String hostName = addr.getHostName();
serverNames.add(new SNIHostName(hostName));
modifiedParams.setServerNames(serverNames);
return true;
serverNames.add(new SNIHostName(address.getHostName()));
parameters.setServerNames(serverNames);
}

/**
Expand Down Expand Up @@ -664,7 +626,7 @@ protected void initializeClientSocketFactory() {
if (className != null) {
Object o;
try {
Class c = ClassPathLoader.getLatest().forName(className);
Class<?> c = ClassPathLoader.getLatest().forName(className);
o = c.newInstance();
} catch (Exception e) {
// No cache exists yet, so this can't be logged.
Expand Down
Loading

0 comments on commit 75ea5f7

Please sign in to comment.