Skip to content

Commit

Permalink
Add spring-websocket module tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rstoyanchev committed May 14, 2013
1 parent 6a5acb9 commit 05084d5
Show file tree
Hide file tree
Showing 52 changed files with 2,927 additions and 387 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ public Cookies getCookies() {
@Override
public MultiValueMap<String, String> getQueryParams() {
if (this.queryParams == null) {
// TODO: extract from query string
this.queryParams = new LinkedMultiValueMap<String, String>(this.servletRequest.getParameterMap().size());
for (String name : this.servletRequest.getParameterMap().keySet()) {
for (String value : this.servletRequest.getParameterValues(name)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import org.springframework.web.util.UriComponentsBuilder;

/**
* Abstract base class for WebSocketConnection managers.
* Abstract base class for WebSocket connection managers.
*
* @author Rossen Stoyanchev
* @since 4.0
Expand Down Expand Up @@ -147,25 +147,25 @@ public void run() {
public final void stop() {
synchronized (this.lifecycleMonitor) {
if (isRunning()) {
stopInternal();
if (logger.isDebugEnabled()) {
logger.debug("Stopping " + this.getClass().getSimpleName());
}
try {
stopInternal();
}
catch (Throwable e) {
logger.error("Failed to stop WebSocket connection", e);
}
finally {
this.isRunning = false;
}
}
}
}

protected void stopInternal() {
if (logger.isDebugEnabled()) {
logger.debug("Stopping " + this.getClass().getSimpleName());
}
try {
if (isConnected()) {
closeConnection();
}
}
catch (Throwable e) {
logger.error("Failed to stop WebSocket connection", e);
}
finally {
this.isRunning = false;
protected void stopInternal() throws Exception {
if (isConnected()) {
closeConnection();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public void startInternal() {
}

@Override
public void stopInternal() {
public void stopInternal() throws Exception {
if (this.syncClientLifecycle) {
((SmartLifecycle) client).stop();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,15 @@ public class StandardWebSocketClient implements WebSocketClient {

private static final Log logger = LogFactory.getLog(StandardWebSocketClient.class);

private static final Set<String> EXCLUDED_HEADERS = new HashSet<String>(
Arrays.asList("Sec-WebSocket-Accept", "Sec-WebSocket-Extensions", "Sec-WebSocket-Key",
"Sec-WebSocket-Protocol", "Sec-WebSocket-Version"));
private WebSocketContainer webSocketContainer;

private WebSocketContainer webSocketContainer = ContainerProvider.getWebSocketContainer();

public WebSocketContainer getWebSocketContainer() {
if (this.webSocketContainer == null) {
this.webSocketContainer = ContainerProvider.getWebSocketContainer();
}
return this.webSocketContainer;
}

public void setWebSocketContainer(WebSocketContainer container) {
this.webSocketContainer = container;
Expand All @@ -72,8 +75,8 @@ public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, String ur
}

@Override
public WebSocketSession doHandshake(WebSocketHandler webSocketHandler,
final HttpHeaders httpHeaders, URI uri) throws WebSocketConnectFailureException {
public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, HttpHeaders httpHeaders, URI uri)
throws WebSocketConnectFailureException {

StandardWebSocketSessionAdapter session = new StandardWebSocketSessionAdapter();
session.setUri(uri);
Expand All @@ -86,29 +89,7 @@ public WebSocketSession doHandshake(WebSocketHandler webSocketHandler,
if (!protocols.isEmpty()) {
configBuidler.preferredSubprotocols(protocols);
}
configBuidler.configurator(new Configurator() {
@Override
public void beforeRequest(Map<String, List<String>> headers) {
for (String headerName : httpHeaders.keySet()) {
if (!EXCLUDED_HEADERS.contains(headerName)) {
List<String> value = httpHeaders.get(headerName);
if (logger.isTraceEnabled()) {
logger.trace("Adding header [" + headerName + "=" + value + "]");
}
headers.put(headerName, value);
}
}
if (logger.isTraceEnabled()) {
logger.trace("Handshake request headers: " + headers);
}
}
@Override
public void afterResponse(HandshakeResponse handshakeResponse) {
if (logger.isTraceEnabled()) {
logger.trace("Handshake response headers: " + handshakeResponse.getHeaders());
}
}
});
configBuidler.configurator(new StandardWebSocketClientConfigurator(httpHeaders));
}

try {
Expand All @@ -121,4 +102,41 @@ public void afterResponse(HandshakeResponse handshakeResponse) {
}
}


private static class StandardWebSocketClientConfigurator extends Configurator {

private static final Set<String> EXCLUDED_HEADERS = new HashSet<String>(
Arrays.asList("Sec-WebSocket-Accept", "Sec-WebSocket-Extensions", "Sec-WebSocket-Key",
"Sec-WebSocket-Protocol", "Sec-WebSocket-Version"));

private final HttpHeaders httpHeaders;


public StandardWebSocketClientConfigurator(HttpHeaders httpHeaders) {
this.httpHeaders = httpHeaders;
}

@Override
public void beforeRequest(Map<String, List<String>> headers) {
for (String headerName : this.httpHeaders.keySet()) {
if (!EXCLUDED_HEADERS.contains(headerName)) {
List<String> value = this.httpHeaders.get(headerName);
if (logger.isTraceEnabled()) {
logger.trace("Adding header [" + headerName + "=" + value + "]");
}
headers.put(headerName, value);
}
}
if (logger.isTraceEnabled()) {
logger.trace("Handshake request headers: " + headers);
}
}
@Override
public void afterResponse(HandshakeResponse handshakeResponse) {
if (logger.isTraceEnabled()) {
logger.trace("Handshake response headers: " + handshakeResponse.getHeaders());
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,13 @@
* @author Rossen Stoyanchev
* @since 4.0
*/
public class EndpointExporter implements InitializingBean, BeanPostProcessor, ApplicationContextAware {
public class ServerEndpointExporter implements InitializingBean, BeanPostProcessor, ApplicationContextAware {

private static final boolean isServletApiPresent =
ClassUtils.isPresent("javax.servlet.ServletContext", EndpointExporter.class.getClassLoader());
ClassUtils.isPresent("javax.servlet.ServletContext", ServerEndpointExporter.class.getClassLoader());

private static Log logger = LogFactory.getLog(ServerEndpointExporter.class);

private static Log logger = LogFactory.getLog(EndpointExporter.class);

private final List<Class<?>> annotatedEndpointClasses = new ArrayList<Class<?>>();

Expand All @@ -63,6 +64,7 @@ public class EndpointExporter implements InitializingBean, BeanPostProcessor, Ap

private ServerContainer serverContainer;


/**
* TODO
* @param annotatedEndpointClasses
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,14 @@

/**
* An implementation of {@link javax.websocket.server.ServerEndpointConfig} that also
* holds the target {@link javax.websocket.Endpoint} as a reference or a bean name.
*
* <p>
* Beans of this type are detected by {@link EndpointExporter} and
* registered with a Java WebSocket runtime at startup.
* holds the target {@link javax.websocket.Endpoint} provided as a reference or as a bean
* name. Beans of this type are detected by {@link ServerEndpointExporter} and registered
* with a Java WebSocket runtime at startup.
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAware {
public class ServerEndpointRegistration implements ServerEndpointConfig, BeanFactoryAware {

private final String path;

Expand All @@ -65,7 +63,7 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw

private final Map<String, Object> userProperties = new HashMap<String, Object>();

private Configurator configurator = new Configurator() {};
private Configurator configurator = new EndpointRegistrationConfigurator();


/**
Expand All @@ -74,15 +72,15 @@ public class EndpointRegistration implements ServerEndpointConfig, BeanFactoryAw
* @param path
* @param endpointClass
*/
public EndpointRegistration(String path, Class<? extends Endpoint> endpointClass) {
public ServerEndpointRegistration(String path, Class<? extends Endpoint> endpointClass) {
Assert.hasText(path, "path must not be empty");
Assert.notNull(endpointClass, "endpointClass is required");
this.path = path;
this.endpointProvider = new BeanCreatingHandlerProvider<Endpoint>(endpointClass);
this.endpoint = null;
}

public EndpointRegistration(String path, Endpoint endpoint) {
public ServerEndpointRegistration(String path, Endpoint endpoint) {
Assert.hasText(path, "path must not be empty");
Assert.notNull(endpoint, "endpoint is required");
this.path = path;
Expand Down Expand Up @@ -152,38 +150,9 @@ public List<Class<? extends Decoder>> getDecoders() {
return this.decoders;
}

/**
* The {@link Configurator#getEndpointInstance(Class)} method is always ignored.
*/
public void setConfigurator(Configurator configurator) {
this.configurator = configurator;
}

@Override
public Configurator getConfigurator() {
return new Configurator() {
@SuppressWarnings("unchecked")
@Override
public <T> T getEndpointInstance(Class<T> clazz) throws InstantiationException {
return (T) EndpointRegistration.this.getEndpoint();
}
@Override
public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
EndpointRegistration.this.configurator.modifyHandshake(sec, request, response);
}
@Override
public boolean checkOrigin(String originHeaderValue) {
return EndpointRegistration.this.configurator.checkOrigin(originHeaderValue);
}
@Override
public String getNegotiatedSubprotocol(List<String> supported, List<String> requested) {
return EndpointRegistration.this.configurator.getNegotiatedSubprotocol(supported, requested);
}
@Override
public List<Extension> getNegotiatedExtensions(List<Extension> installed, List<Extension> requested) {
return EndpointRegistration.this.configurator.getNegotiatedExtensions(installed, requested);
}
};
return this.configurator;
}

@Override
Expand All @@ -193,4 +162,50 @@ public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
}
}

protected void modifyHandshake(HandshakeRequest request, HandshakeResponse response) {
this.configurator.modifyHandshake(this, request, response);
}

protected boolean checkOrigin(String originHeaderValue) {
return this.configurator.checkOrigin(originHeaderValue);
}

protected String getNegotiatedSubprotocol(List<String> supported, List<String> requested) {
return this.configurator.getNegotiatedSubprotocol(supported, requested);
}

protected List<Extension> getNegotiatedExtensions(List<Extension> installed, List<Extension> requested) {
return this.configurator.getNegotiatedExtensions(installed, requested);
}


private class EndpointRegistrationConfigurator extends Configurator {

@SuppressWarnings("unchecked")
@Override
public <T> T getEndpointInstance(Class<T> clazz) throws InstantiationException {
return (T) ServerEndpointRegistration.this.getEndpoint();
}

@Override
public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
super.modifyHandshake(sec, request, response);
}

@Override
public boolean checkOrigin(String originHeaderValue) {
return super.checkOrigin(originHeaderValue);
}

@Override
public String getNegotiatedSubprotocol(List<String> supported, List<String> requested) {
return super.getNegotiatedSubprotocol(supported, requested);
}

@Override
public List<Extension> getNegotiatedExtensions(List<Extension> installed, List<Extension> requested) {
return super.getNegotiatedExtensions(installed, requested);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
* using its setters allows configuring the {@code ServerContainer} through Spring
* configuration. This is useful even if the ServerContainer is not injected into any
* other bean. For example, an application can configure a {@link DefaultHandshakeHandler}
* , a {@link SockJsService}, or {@link EndpointExporter}, and separately declare this
* , a {@link SockJsService}, or {@link ServerEndpointExporter}, and separately declare this
* FactoryBean in order to customize the properties of the (one and only)
* {@code ServerContainer} instance.
*
Expand All @@ -44,9 +44,6 @@
public class ServletServerContainerFactoryBean
implements FactoryBean<WebSocketContainer>, InitializingBean, ServletContextAware {

private static final String SERVER_CONTAINER_ATTR_NAME = "javax.websocket.server.ServerContainer";


private Long asyncSendTimeout;

private Long maxSessionIdleTimeout;
Expand Down Expand Up @@ -92,7 +89,7 @@ public Integer getMaxBinaryMessageBufferSize() {

@Override
public void setServletContext(ServletContext servletContext) {
this.serverContainer = (ServerContainer) servletContext.getAttribute(SERVER_CONTAINER_ATTR_NAME);
this.serverContainer = (ServerContainer) servletContext.getAttribute("javax.websocket.server.ServerContainer");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
import org.springframework.web.context.WebApplicationContext;

/**
* This should be used in conjuction with {@link ServerEndpoint @ServerEndpoint} classes.
* This should be used in conjunction with {@link ServerEndpoint @ServerEndpoint} classes.
*
* <p>For {@link javax.websocket.Endpoint}, see {@link EndpointExporter}.
* <p>For {@link javax.websocket.Endpoint}, see {@link ServerEndpointExporter}.
*
* @author Rossen Stoyanchev
* @since 4.0
Expand All @@ -56,7 +56,7 @@ public <T> T getEndpointInstance(Class<T> endpointClass) throws InstantiationExc
}
return wac.getAutowireCapableBeanFactory().createBean(endpointClass);
}
if (beans.size() == 1) {
else if (beans.size() == 1) {
if (logger.isTraceEnabled()) {
logger.trace("Using @ServerEndpoint singleton " + beans.keySet().iterator().next());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package org.springframework.web.socket.server.endpoint;


public class Test {

}
Loading

0 comments on commit 05084d5

Please sign in to comment.