Skip to content

Commit

Permalink
Add SubProtocolCapable interface
Browse files Browse the repository at this point in the history
The addition of SubProtocolCapable simplifies configuration since it is
no longer necessary to explicitly configure DefaultHandshakeHandler
with a list of supported sub-protocols. We will not also check if the
WebSocketHandler to use for the WebSocket request is an instance of
SubProtocolCapable and obtain the list of sub-protocols that way. The
provided SubProtocolWebSocketHandler does implement this interface.

Issue: SPR-11111
  • Loading branch information
rstoyanchev committed Nov 25, 2013
1 parent 59002f2 commit 4e82416
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,7 @@

package org.springframework.web.socket.messaging;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.commons.logging.Log;
Expand All @@ -37,6 +32,7 @@
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.support.SubProtocolCapable;


/**
Expand All @@ -55,7 +51,7 @@
*
* @since 4.0
*/
public class SubProtocolWebSocketHandler implements WebSocketHandler, MessageHandler {
public class SubProtocolWebSocketHandler implements SubProtocolCapable, WebSocketHandler, MessageHandler {

private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class);

Expand Down Expand Up @@ -136,8 +132,8 @@ public SubProtocolHandler getDefaultProtocolHandler() {
/**
* Return all supported protocols.
*/
public Set<String> getSupportedProtocols() {
return this.protocolHandlers.keySet();
public List<String> getSubProtocols() {
return new ArrayList<String>(this.protocolHandlers.keySet());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,9 @@ private static SubProtocolWebSocketHandler unwrapSubProtocolWebSocketHandler(Web
public StompWebSocketEndpointRegistration addEndpoint(String... paths) {

this.subProtocolWebSocketHandler.addProtocolHandler(this.stompHandler);
Set<String> subProtocols = this.subProtocolWebSocketHandler.getSupportedProtocols();

WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration(
paths, this.webSocketHandler, subProtocols, this.sockJsScheduler);
paths, this.webSocketHandler, this.sockJsScheduler);
this.registrations.add(registration);

return registration;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE

private final WebSocketHandler webSocketHandler;

private final String[] subProtocols;

private final TaskScheduler sockJsTaskScheduler;

private HandshakeHandler handshakeHandler;
Expand All @@ -56,28 +54,14 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE


public WebMvcStompWebSocketEndpointRegistration(String[] paths, WebSocketHandler webSocketHandler,
Set<String> subProtocols, TaskScheduler sockJsTaskScheduler) {
TaskScheduler sockJsTaskScheduler) {

Assert.notEmpty(paths, "No paths specified");
Assert.notNull(webSocketHandler, "'webSocketHandler' is required");
Assert.notNull(subProtocols, "'subProtocols' is required");

this.paths = paths;
this.webSocketHandler = webSocketHandler;
this.subProtocols = subProtocols.toArray(new String[subProtocols.size()]);
this.sockJsTaskScheduler = sockJsTaskScheduler;

this.handshakeHandler = new DefaultHandshakeHandler();
updateHandshakeHandler();
}

private void updateHandshakeHandler() {
if (handshakeHandler instanceof DefaultHandshakeHandler) {
DefaultHandshakeHandler defaultHandshakeHandler = (DefaultHandshakeHandler) handshakeHandler;
if (ObjectUtils.isEmpty(defaultHandshakeHandler.getSupportedProtocols())) {
defaultHandshakeHandler.setSupportedProtocols(this.subProtocols);
}
}
}

/**
Expand All @@ -87,7 +71,6 @@ private void updateHandshakeHandler() {
public StompWebSocketEndpointRegistration setHandshakeHandler(HandshakeHandler handshakeHandler) {
Assert.notNull(handshakeHandler, "'handshakeHandler' must not be null");
this.handshakeHandler = handshakeHandler;
updateHandshakeHandler();
return this;
}

Expand All @@ -97,8 +80,10 @@ public StompWebSocketEndpointRegistration setHandshakeHandler(HandshakeHandler h
@Override
public SockJsServiceRegistration withSockJS() {
this.registration = new StompSockJsServiceRegistration(this.sockJsTaskScheduler);
WebSocketTransportHandler transportHandler = new WebSocketTransportHandler(this.handshakeHandler);
this.registration.setTransportHandlerOverrides(transportHandler);
if (this.handshakeHandler != null) {
WebSocketTransportHandler transportHandler = new WebSocketTransportHandler(this.handshakeHandler);
this.registration.setTransportHandlerOverrides(transportHandler);
}
return this.registration;
}

Expand All @@ -114,8 +99,9 @@ protected final MultiValueMap<HttpRequestHandler, String> getMappings() {
}
else {
for (String path : this.paths) {
WebSocketHttpRequestHandler handler =
new WebSocketHttpRequestHandler(this.webSocketHandler, this.handshakeHandler);
WebSocketHttpRequestHandler handler = (this.handshakeHandler != null) ?
new WebSocketHttpRequestHandler(this.webSocketHandler, this.handshakeHandler) :
new WebSocketHttpRequestHandler(this.webSocketHandler);
mappings.add(handler, path);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.support.SubProtocolCapable;
import org.springframework.web.socket.support.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.support.WebSocketHandlerDecorator;
import org.springframework.web.socket.support.WebSocketHttpHeaders;

/**
Expand Down Expand Up @@ -122,10 +124,16 @@ public DefaultHandshakeHandler(RequestUpgradeStrategy upgradeStrategy) {
}

/**
* Use this property to configure a list of sub-protocols that are supported.
* The first protocol that matches what the client requested is selected.
* If no protocol matches or this property is not configured, then the
* response will not contain a Sec-WebSocket-Protocol header.
* Use this property to configure the list of supported sub-protocols.
* The first configured sub-protocol that matches a client-requested sub-protocol
* is accepted. If there are no matches the response will not contain a
* {@literal Sec-WebSocket-Protocol} header.
* <p>
* Note that if the WebSocketHandler passed in at runtime is an instance of
* {@link SubProtocolCapable} then there is not need to explicitly configure
* this property. That is certainly the case with the built-in STOMP over
* WebSocket support. Therefore this property should be configured explicitly
* only if the WebSocketHandler does not implement {@code SubProtocolCapable}.
*/
public void setSupportedProtocols(String... protocols) {
this.supportedProtocols.clear();
Expand Down Expand Up @@ -187,7 +195,10 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
"Response update failed during upgrade to WebSocket, uri=" + request.getURI(), ex);
}

String subProtocol = selectProtocol(headers.getSecWebSocketProtocol());
String subProtocol = selectProtocol(headers.getSecWebSocketProtocol(), wsHandler);
if (logger.isDebugEnabled()) {
logger.debug("Selected sub-protocol: '" + subProtocol + "'");
}

List<WebSocketExtension> requested = headers.getSecWebSocketExtensions();
List<WebSocketExtension> supported = this.requestUpgradeStrategy.getSupportedExtensions(request);
Expand Down Expand Up @@ -246,24 +257,60 @@ protected boolean isValidOrigin(ServerHttpRequest request) {
return true;
}

protected String selectProtocol(List<String> requestedProtocols) {
/**
* Perform the sub-protocol negotiation based on requested and supported sub-protocols.
* For the list of supported sub-protocols, this method first checks if the target
* WebSocketHandler is a {@link SubProtocolCapable} and then also checks if any
* sub-protocols have been explicitly configured with
* {@link #setSupportedProtocols(String...)}.
*
* @param requestedProtocols the requested sub-protocols
* @param webSocketHandler the WebSocketHandler that will be used
* @return the selected protocols or {@code null}
*
* @see #determineHandlerSupportedProtocols(org.springframework.web.socket.WebSocketHandler)
*/
protected String selectProtocol(List<String> requestedProtocols, WebSocketHandler webSocketHandler) {
if (requestedProtocols != null) {
List<String> handlerProtocols = determineHandlerSupportedProtocols(webSocketHandler);
if (logger.isDebugEnabled()) {
logger.debug("Requested sub-protocol(s): " + requestedProtocols
+ ", supported sub-protocol(s): " + this.supportedProtocols);
logger.debug("Requested sub-protocol(s): " + requestedProtocols +
", WebSocketHandler supported sub-protocol(s): " + handlerProtocols +
", configured sub-protocol(s): " + this.supportedProtocols);
}
for (String protocol : requestedProtocols) {
if (handlerProtocols.contains(protocol.toLowerCase())) {
return protocol;
}
if (this.supportedProtocols.contains(protocol.toLowerCase())) {
if (logger.isDebugEnabled()) {
logger.debug("Selected sub-protocol: '" + protocol + "'");
}
return protocol;
}
}
}
return null;
}

/**
* Determine the sub-protocols supported by the given WebSocketHandler by checking
* whether it is an instance of {@link SubProtocolCapable}.
*
* @param handler the handler to check
* @return a list of supported protocols or an empty list
*/
protected final List<String> determineHandlerSupportedProtocols(WebSocketHandler handler) {
List<String> subProtocols = null;
if (handler instanceof SubProtocolCapable) {
subProtocols = ((SubProtocolCapable) handler).getSubProtocols();
}
else if (handler instanceof WebSocketHandlerDecorator) {
WebSocketHandler lastHandler = ((WebSocketHandlerDecorator) handler).getLastHandler();
if (lastHandler instanceof SubProtocolCapable) {
subProtocols = ((SubProtocolCapable) lastHandler).getSubProtocols();;
}
}
return (subProtocols != null) ? subProtocols : Collections.<String>emptyList();
}

/**
* Filter the list of requested WebSocket extensions.
* <p>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.springframework.web.socket.support;

import java.util.List;

/**
* An interface for WebSocket handlers that support sub-protocols as defined in RFC 6455.
*
* @author Rossen Stoyanchev
* @since 4.0
*
* @see <a href="http://tools.ietf.org/html/rfc6455#section-1.9">RFC-6455 section 1.9</a>
*/
public interface SubProtocolCapable {

/**
* Return the list of supported sub-protocols.
*/
List<String> getSubProtocols();
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package org.springframework.web.socket.messaging.config;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -62,7 +61,7 @@ public void minimalRegistration() {


WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration(
new String[] {"/foo"}, this.wsHandler, Collections.<String>emptySet(), this.scheduler);
new String[] {"/foo"}, this.wsHandler, this.scheduler);

MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
assertEquals(1, mappings.size());
Expand All @@ -78,7 +77,7 @@ public void customHandshakeHandler() {
DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();

WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration(
new String[] {"/foo"}, this.wsHandler, Collections.<String>emptySet(), this.scheduler);
new String[] {"/foo"}, this.wsHandler, this.scheduler);

registration.setHandshakeHandler(handshakeHandler);

Expand All @@ -99,7 +98,7 @@ public void customHandshakeHandlerPassedToSockJsService() {
DefaultHandshakeHandler handshakeHandler = new DefaultHandshakeHandler();

WebMvcStompWebSocketEndpointRegistration registration = new WebMvcStompWebSocketEndpointRegistration(
new String[] {"/foo"}, this.wsHandler, Collections.<String>emptySet(), this.scheduler);
new String[] {"/foo"}, this.wsHandler, this.scheduler);

registration.setHandshakeHandler(handshakeHandler);
registration.withSockJS();
Expand Down
Loading

0 comments on commit 4e82416

Please sign in to comment.