Skip to content

Commit

Permalink
Protect STOMP passcode from showing up in logs
Browse files Browse the repository at this point in the history
Issue: SRP-10868
  • Loading branch information
rstoyanchev committed Aug 29, 2013
1 parent 1472e97 commit 80812d3
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
*/
public class StompHeaderAccessor extends SimpMessageHeaderAccessor {

private static final AtomicLong messageIdCounter = new AtomicLong();

// STOMP header names

public static final String STOMP_ID_HEADER = "id";
Expand Down Expand Up @@ -83,10 +85,9 @@ public class StompHeaderAccessor extends SimpMessageHeaderAccessor {

// Other header names

public static final String COMMAND_HEADER = "stompCommand";
private static final String COMMAND_HEADER = "stompCommand";


private static final AtomicLong messageIdCounter = new AtomicLong();
private static final String CREDENTIALS_HEADER = "stompCredentials";


/**
Expand Down Expand Up @@ -128,6 +129,12 @@ else if (StompCommand.MESSAGE.equals(command)) {
super.setSubscriptionId(values.get(0));
}
}
else if (StompCommand.CONNECT.equals(command)) {
if (!StringUtils.isEmpty(getPasscode())) {
setHeader(CREDENTIALS_HEADER, new StompPasscode(getPasscode()));
setPasscode("PROTECTED");
}
}
}

/**
Expand Down Expand Up @@ -197,6 +204,18 @@ public Map<String, List<String>> toNativeHeaderMap() {
return result;
}

public Map<String, List<String>> toStompHeaderMap() {
if (StompCommand.CONNECT.equals(getCommand())) {
StompPasscode credentials = (StompPasscode) getHeader(CREDENTIALS_HEADER);
if (credentials != null) {
Map<String, List<String>> headers = toNativeHeaderMap();
headers.put(STOMP_PASSCODE_HEADER, Arrays.asList(credentials.passcode));
return headers;
}
}
return toNativeHeaderMap();
}

public void setCommandIfNotSet(StompCommand command) {
if (getCommand() == null) {
setHeader(COMMAND_HEADER, command);
Expand Down Expand Up @@ -338,4 +357,18 @@ public void setVersion(String version) {
setNativeHeader(STOMP_VERSION_HEADER, version);
}


private static class StompPasscode {

private final String passcode;

public StompPasscode(String passcode) {
this.passcode = passcode;
}

@Override
public String toString() {
return "[PROTECTED]";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ public byte[] fromMessage(Message<?> message) {
try {
out.write(stompHeaders.getCommand().toString().getBytes("UTF-8"));
out.write(LF);
for (Entry<String, List<String>> entry : stompHeaders.toNativeHeaderMap().entrySet()) {
for (Entry<String, List<String>> entry : stompHeaders.toStompHeaderMap().entrySet()) {
String key = entry.getKey();
key = replaceAllOutbound(key);
for (String value : entry.getValue()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,26 @@ public void createWithMessageFrameNativeHeaders() {
assertEquals("s1", headers.getSubscriptionId());
}

@Test
public void createWithConnectNativeHeaders() {

MultiValueMap<String, String> extHeaders = new LinkedMultiValueMap<>();
extHeaders.add(StompHeaderAccessor.STOMP_LOGIN_HEADER, "joe");
extHeaders.add(StompHeaderAccessor.STOMP_PASSCODE_HEADER, "joe123");

StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT, extHeaders);

assertEquals(StompCommand.CONNECT, headers.getCommand());
assertEquals(SimpMessageType.CONNECT, headers.getMessageType());
assertNotNull(headers.getHeader("stompCredentials"));
assertEquals("joe", headers.getLogin());
assertEquals("PROTECTED", headers.getPasscode());

Map<String, List<String>> output = headers.toStompHeaderMap();
assertEquals("joe", output.get(StompHeaderAccessor.STOMP_LOGIN_HEADER).get(0));
assertEquals("joe123", output.get(StompHeaderAccessor.STOMP_PASSCODE_HEADER).get(0));
}

@Test
public void toNativeHeadersSubscribe() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ protected int getPayloadSize() {

@Override
protected String toStringPayload() {
return (getPayloadSize() > 80) ? getPayload().substring(0, 80) + "..." : getPayload();
return (getPayloadSize() > 10) ? getPayload().substring(0, 10) + ".." : getPayload();
}

}

0 comments on commit 80812d3

Please sign in to comment.