Skip to content

Commit

Permalink
Add StompProtocolHandler tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rstoyanchev committed Aug 29, 2013
1 parent 364bc35 commit 39ff1e2
Show file tree
Hide file tree
Showing 7 changed files with 268 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.springframework.messaging.handler.websocket.SubProtocolHandler;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.handler.MutableUserQueueSuffixResolver;
import org.springframework.messaging.simp.handler.SimpleUserQueueSuffixResolver;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.web.socket.CloseStatus;
Expand Down Expand Up @@ -64,7 +65,7 @@ public class StompProtocolHandler implements SubProtocolHandler {

private final StompMessageConverter stompMessageConverter = new StompMessageConverter();

private MutableUserQueueSuffixResolver queueSuffixResolver;
private MutableUserQueueSuffixResolver queueSuffixResolver = new SimpleUserQueueSuffixResolver();


/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,12 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.Message;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.simp.AbstractWebSocketIntegrationTests;
import org.springframework.messaging.simp.JettyTestServer;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.stomp.StompMessageConverter;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.simp.stomp.StompTextMessageBuilder;
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel;
import org.springframework.stereotype.Controller;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
Expand Down Expand Up @@ -76,16 +73,13 @@ public void sendMessage() throws Exception {
this.server.init(cxt);
this.server.start();

StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
headers.setDestination("/app/foo");
Message<byte[]> message = MessageBuilder.withPayloadAndHeaders(new byte[0], headers).build();
byte[] bytes = new StompMessageConverter().fromMessage(message);
final TextMessage webSocketMessage = new TextMessage(new String(bytes));
final TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.SEND)
.headers("destination:/app/foo").build();

WebSocketHandler clientHandler = new TextWebSocketHandlerAdapter() {
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
session.sendMessage(webSocketMessage);
session.sendMessage(textMessage);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.web.socket.TextMessage;

import static org.junit.Assert.*;

Expand All @@ -41,14 +42,17 @@ public void setup() {
this.converter = new StompMessageConverter();
}

@SuppressWarnings("unchecked")
@Test
public void connectFrame() throws Exception {

String accept = "accept-version:1.1\n";
String host = "host:github.org\n";
String frame = "\n\n\nCONNECT\n" + accept + host + "\n";
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(frame.getBytes("UTF-8"));
String accept = "accept-version:1.1";
String host = "host:github.org";

TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT)
.headers(accept, host).build();

@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(textMessage.getPayload());

assertEquals(0, message.getPayload().length);

Expand Down Expand Up @@ -80,11 +84,14 @@ public void connectFrame() throws Exception {
@Test
public void connectWithEscapes() throws Exception {

String accept = "accept-version:1.1\n";
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org\n";
String frame = "CONNECT\n" + accept + host + "\n";
String accept = "accept-version:1.1";
String host = "ho\\c\\ns\\rt:st\\nomp.gi\\cthu\\b.org";

TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT)
.headers(accept, host).build();

@SuppressWarnings("unchecked")
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(frame.getBytes("UTF-8"));
Message<byte[]> message = (Message<byte[]>) this.converter.toMessage(textMessage.getPayload());

assertEquals(0, message.getPayload().length);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.messaging.simp.stomp;

import java.util.Arrays;
import java.util.HashSet;

import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.support.TestPrincipal;
import org.springframework.web.socket.support.TestWebSocketSession;

import static org.junit.Assert.*;
import static org.mockito.Mockito.*;

/**
* Test fixture for {@link StompProtocolHandler} tests.
*
* @author Rossen Stoyanchev
*/
public class StompProtocolHandlerTests {

private StompProtocolHandler stompHandler;

private TestWebSocketSession session;

private MessageChannel channel;

private ArgumentCaptor<Message> messageCaptor;


@Before
public void setup() {
this.stompHandler = new StompProtocolHandler();
this.channel = Mockito.mock(MessageChannel.class);
this.messageCaptor = ArgumentCaptor.forClass(Message.class);

this.session = new TestWebSocketSession();
this.session.setId("s1");
this.session.setPrincipal(new TestPrincipal("joe"));
}

@Test
public void handleConnect() {

TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).headers(
"login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build();

this.stompHandler.handleMessageFromClient(this.session, textMessage, this.channel);

verify(this.channel).send(this.messageCaptor.capture());
Message<?> actual = this.messageCaptor.getValue();
assertNotNull(actual);

StompHeaderAccessor headers = StompHeaderAccessor.wrap(actual);
assertEquals(StompCommand.CONNECT, headers.getCommand());
assertEquals("s1", headers.getSessionId());
assertEquals("joe", headers.getUser().getName());
assertEquals("guest", headers.getLogin());
assertEquals("PROTECTED", headers.getPasscode());
assertArrayEquals(new long[] {10000, 10000}, headers.getHeartbeat());
assertEquals(new HashSet<>(Arrays.asList("1.1","1.0")), headers.getAcceptVersion());

// Check CONNECTED reply

assertEquals(1, this.session.getSentMessages().size());
textMessage = (TextMessage) this.session.getSentMessages().get(0);
Message<?> message = new StompMessageConverter().toMessage(textMessage.getPayload());
StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(message);

assertEquals(StompCommand.CONNECTED, replyHeaders.getCommand());
assertEquals("1.1", replyHeaders.getVersion());
assertArrayEquals(new long[] {0, 0}, replyHeaders.getHeartbeat());
assertEquals("joe", replyHeaders.getNativeHeader("user-name").get(0));
assertEquals("s1", replyHeaders.getNativeHeader("queue-suffix").get(0));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.messaging.simp.stomp;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.springframework.web.socket.TextMessage;


/**
* A builder for creating WebSocket messages with STOMP frame content.
*
* @author Rossen Stoyanchev
*/
public class StompTextMessageBuilder {

private StompCommand command;

private final List<String> headerLines = new ArrayList<String>();

private String body;


private StompTextMessageBuilder(StompCommand command) {
this.command = command;
}

public static StompTextMessageBuilder create(StompCommand command) {
return new StompTextMessageBuilder(command);
}

public StompTextMessageBuilder headers(String... headerLines) {
this.headerLines.addAll(Arrays.asList(headerLines));
return this;
}

public StompTextMessageBuilder body(String body) {
this.body = body;
return this;
}

public TextMessage build() {
StringBuilder sb = new StringBuilder(this.command.name()).append("\n");
for (String line : this.headerLines) {
sb.append(line).append("\n");
}
sb.append("\n");
if (this.body != null) {
sb.append(this.body);
}
sb.append("\u0000");
return new TextMessage(sb.toString());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,26 @@
package org.springframework.web.socket.server.config;

import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import org.mockito.Mockito;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.socket.AbstractWebSocketIntegrationTests;
import org.springframework.web.socket.JettyTestServer;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.WebSocketHandlerAdapter;
import org.springframework.web.socket.client.jetty.JettyWebSocketClient;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;

import static org.mockito.Matchers.*;
import static org.mockito.Mockito.*;
import static org.junit.Assert.*;


/**
Expand All @@ -63,13 +63,10 @@ public void registerWebSocketHandler() throws Exception {
this.server.init(cxt);
this.server.start();

WebSocketHandler clientHandler = Mockito.mock(WebSocketHandler.class);
WebSocketHandler serverHandler = cxt.getBean(WebSocketHandler.class);
this.webSocketClient.doHandshake(new WebSocketHandlerAdapter(), getWsBaseUrl() + "/ws");

this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/ws");

verify(serverHandler).afterConnectionEstablished(any(WebSocketSession.class));
verify(clientHandler).afterConnectionEstablished(any(WebSocketSession.class));
TestWebSocketHandler serverHandler = cxt.getBean(TestWebSocketHandler.class);
assertTrue(serverHandler.latch.await(2, TimeUnit.SECONDS));
}

@Test
Expand All @@ -81,13 +78,10 @@ public void registerWebSocketHandlerWithSockJS() throws Exception {
this.server.init(cxt);
this.server.start();

WebSocketHandler clientHandler = Mockito.mock(WebSocketHandler.class);
WebSocketHandler serverHandler = cxt.getBean(WebSocketHandler.class);

this.webSocketClient.doHandshake(clientHandler, getWsBaseUrl() + "/sockjs/websocket");
this.webSocketClient.doHandshake(new WebSocketHandlerAdapter(), getWsBaseUrl() + "/sockjs/websocket");

verify(serverHandler).afterConnectionEstablished(any(WebSocketSession.class));
verify(clientHandler).afterConnectionEstablished(any(WebSocketSession.class));
TestWebSocketHandler serverHandler = cxt.getBean(TestWebSocketHandler.class);
assertTrue(serverHandler.latch.await(2, TimeUnit.SECONDS));
}


Expand All @@ -110,8 +104,18 @@ public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
}

@Bean
public WebSocketHandler serverHandler() {
return Mockito.mock(WebSocketHandler.class);
public TestWebSocketHandler serverHandler() {
return new TestWebSocketHandler();
}
}

private static class TestWebSocketHandler extends WebSocketHandlerAdapter {

private CountDownLatch latch = new CountDownLatch(1);

@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
this.latch.countDown();
}
}

Expand Down
Loading

0 comments on commit 39ff1e2

Please sign in to comment.