Skip to content

Commit

Permalink
[improve][broker]add ServerCnx state check before server handle reque…
Browse files Browse the repository at this point in the history
…st (apache#17084)
  • Loading branch information
HQebupt authored Aug 19, 2022
1 parent 14912a6 commit 694aa13
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ private CompletableFuture<Boolean> isTopicOperationAllowed(TopicName topicName,

@Override
protected void handleLookup(CommandLookupTopic lookup) {
checkArgument(state == State.Connected);
final long requestId = lookup.getRequestId();
final boolean authoritative = lookup.isAuthoritative();

Expand Down Expand Up @@ -504,6 +505,7 @@ protected void handleLookup(CommandLookupTopic lookup) {

@Override
protected void handlePartitionMetadataRequest(CommandPartitionedTopicMetadata partitionMetadata) {
checkArgument(state == State.Connected);
final long requestId = partitionMetadata.getRequestId();
if (log.isDebugEnabled()) {
log.debug("[{}] Received PartitionMetadataLookup from {} for {}", partitionMetadata.getTopic(),
Expand Down Expand Up @@ -580,6 +582,7 @@ protected void handlePartitionMetadataRequest(CommandPartitionedTopicMetadata pa

@Override
protected void handleConsumerStats(CommandConsumerStats commandConsumerStats) {
checkArgument(state == State.Connected);
if (log.isDebugEnabled()) {
log.debug("Received CommandConsumerStats call from {}", remoteAddress);
}
Expand Down Expand Up @@ -1988,6 +1991,7 @@ private CompletableFuture<Boolean> isNamespaceOperationAllowed(NamespaceName nam

@Override
protected void handleGetTopicsOfNamespace(CommandGetTopicsOfNamespace commandGetTopicsOfNamespace) {
checkArgument(state == State.Connected);
final long requestId = commandGetTopicsOfNamespace.getRequestId();
final String namespace = commandGetTopicsOfNamespace.getNamespace();
final CommandGetTopicsOfNamespace.Mode mode = commandGetTopicsOfNamespace.getMode();
Expand Down Expand Up @@ -2076,6 +2080,7 @@ protected void handleGetTopicsOfNamespace(CommandGetTopicsOfNamespace commandGet

@Override
protected void handleGetSchema(CommandGetSchema commandGetSchema) {
checkArgument(state == State.Connected);
if (log.isDebugEnabled()) {
if (commandGetSchema.hasSchemaVersion()) {
log.debug("Received CommandGetSchema call from {}, schemaVersion: {}, topic: {}, requestId: {}",
Expand Down Expand Up @@ -2123,6 +2128,7 @@ remoteAddress, new String(commandGetSchema.getSchemaVersion()),

@Override
protected void handleGetOrCreateSchema(CommandGetOrCreateSchema commandGetOrCreateSchema) {
checkArgument(state == State.Connected);
if (log.isDebugEnabled()) {
log.debug("Received CommandGetOrCreateSchema call from {}", remoteAddress);
}
Expand Down Expand Up @@ -2158,6 +2164,7 @@ protected void handleGetOrCreateSchema(CommandGetOrCreateSchema commandGetOrCrea

@Override
protected void handleTcClientConnectRequest(CommandTcClientConnectRequest command) {
checkArgument(state == State.Connected);
final long requestId = command.getRequestId();
final TransactionCoordinatorID tcId = TransactionCoordinatorID.get(command.getTcId());
if (log.isDebugEnabled()) {
Expand Down Expand Up @@ -2219,6 +2226,7 @@ private Throwable handleTxnException(Throwable ex, String op, long requestId) {
}
@Override
protected void handleNewTxn(CommandNewTxn command) {
checkArgument(state == State.Connected);
final long requestId = command.getRequestId();
final TransactionCoordinatorID tcId = TransactionCoordinatorID.get(command.getTcId());
if (log.isDebugEnabled()) {
Expand Down Expand Up @@ -2260,6 +2268,7 @@ protected void handleNewTxn(CommandNewTxn command) {

@Override
protected void handleAddPartitionToTxn(CommandAddPartitionToTxn command) {
checkArgument(state == State.Connected);
final TxnID txnID = new TxnID(command.getTxnidMostBits(), command.getTxnidLeastBits());
final TransactionCoordinatorID tcId = TransactionCoordinatorID.get(command.getTxnidMostBits());
final long requestId = command.getRequestId();
Expand Down Expand Up @@ -2297,6 +2306,7 @@ protected void handleAddPartitionToTxn(CommandAddPartitionToTxn command) {

@Override
protected void handleEndTxn(CommandEndTxn command) {
checkArgument(state == State.Connected);
final long requestId = command.getRequestId();
final int txnAction = command.getTxnAction().getValue();
TxnID txnID = new TxnID(command.getTxnidMostBits(), command.getTxnidLeastBits());
Expand Down Expand Up @@ -2327,6 +2337,7 @@ protected void handleEndTxn(CommandEndTxn command) {

@Override
protected void handleEndTxnOnPartition(CommandEndTxnOnPartition command) {
checkArgument(state == State.Connected);
final long requestId = command.getRequestId();
final String topic = command.getTopic();
final int txnAction = command.getTxnAction().getValue();
Expand Down Expand Up @@ -2397,6 +2408,7 @@ protected void handleEndTxnOnPartition(CommandEndTxnOnPartition command) {

@Override
protected void handleEndTxnOnSubscription(CommandEndTxnOnSubscription command) {
checkArgument(state == State.Connected);
final long requestId = command.getRequestId();
final long txnidMostBits = command.getTxnidMostBits();
final long txnidLeastBits = command.getTxnidLeastBits();
Expand Down Expand Up @@ -2503,6 +2515,7 @@ private CompletableFuture<SchemaVersion> tryAddSchema(Topic topic, SchemaData sc

@Override
protected void handleAddSubscriptionToTxn(CommandAddSubscriptionToTxn command) {
checkArgument(state == State.Connected);
final TxnID txnID = new TxnID(command.getTxnidMostBits(), command.getTxnidLeastBits());
final long requestId = command.getRequestId();
if (log.isDebugEnabled()) {
Expand Down Expand Up @@ -2541,6 +2554,7 @@ protected void handleAddSubscriptionToTxn(CommandAddSubscriptionToTxn command) {
}

protected void handleCommandWatchTopicList(CommandWatchTopicList commandWatchTopicList) {
checkArgument(state == State.Connected);
final long requestId = commandWatchTopicList.getRequestId();
final long watcherId = commandWatchTopicList.getWatcherId();
final NamespaceName namespaceName = NamespaceName.get(commandWatchTopicList.getNamespace());
Expand Down Expand Up @@ -2590,6 +2604,7 @@ protected void handleCommandWatchTopicList(CommandWatchTopicList commandWatchTop
}

protected void handleCommandWatchTopicListClose(CommandWatchTopicListClose commandWatchTopicListClose) {
checkArgument(state == State.Connected);
topicListService.handleWatchTopicListClose(commandWatchTopicListClose);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2153,4 +2153,139 @@ public void testHandleAuthResponseWithoutClientVersion() {
verify(authResponse, times(1)).hasClientVersion();
verify(authResponse, times(0)).getClientVersion();
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void shouldFailHandleLookup() throws Exception {
ServerCnx serverCnx = mock(ServerCnx.class, CALLS_REAL_METHODS);
Field stateUpdater = ServerCnx.class.getDeclaredField("state");
stateUpdater.setAccessible(true);
stateUpdater.set(serverCnx, ServerCnx.State.Failed);
serverCnx.handleLookup(any());
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void shouldFailHandlePartitionMetadataRequest() throws Exception {
ServerCnx serverCnx = mock(ServerCnx.class, CALLS_REAL_METHODS);
Field stateUpdater = ServerCnx.class.getDeclaredField("state");
stateUpdater.setAccessible(true);
stateUpdater.set(serverCnx, ServerCnx.State.Failed);
serverCnx.handlePartitionMetadataRequest(any());
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void shouldFailHandleConsumerStats() throws Exception {
ServerCnx serverCnx = mock(ServerCnx.class, CALLS_REAL_METHODS);
Field stateUpdater = ServerCnx.class.getDeclaredField("state");
stateUpdater.setAccessible(true);
stateUpdater.set(serverCnx, ServerCnx.State.Failed);
serverCnx.handleConsumerStats(any());
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void shouldFailHandleGetTopicsOfNamespace() throws Exception {
ServerCnx serverCnx = mock(ServerCnx.class, CALLS_REAL_METHODS);
Field stateUpdater = ServerCnx.class.getDeclaredField("state");
stateUpdater.setAccessible(true);
stateUpdater.set(serverCnx, ServerCnx.State.Failed);
serverCnx.handleGetTopicsOfNamespace(any());
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void shouldFailHandleGetSchema() throws Exception {
ServerCnx serverCnx = mock(ServerCnx.class, CALLS_REAL_METHODS);
Field stateUpdater = ServerCnx.class.getDeclaredField("state");
stateUpdater.setAccessible(true);
stateUpdater.set(serverCnx, ServerCnx.State.Failed);
serverCnx.handleGetSchema(any());
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void shouldFailHandleGetOrCreateSchema() throws Exception {
ServerCnx serverCnx = mock(ServerCnx.class, CALLS_REAL_METHODS);
Field stateUpdater = ServerCnx.class.getDeclaredField("state");
stateUpdater.setAccessible(true);
stateUpdater.set(serverCnx, ServerCnx.State.Failed);
serverCnx.handleGetOrCreateSchema(any());
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void shouldFailHandleTcClientConnectRequest() throws Exception {
ServerCnx serverCnx = mock(ServerCnx.class, CALLS_REAL_METHODS);
Field stateUpdater = ServerCnx.class.getDeclaredField("state");
stateUpdater.setAccessible(true);
stateUpdater.set(serverCnx, ServerCnx.State.Failed);
serverCnx.handleTcClientConnectRequest(any());
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void shouldFailHandleNewTxn() throws Exception {
ServerCnx serverCnx = mock(ServerCnx.class, CALLS_REAL_METHODS);
Field stateUpdater = ServerCnx.class.getDeclaredField("state");
stateUpdater.setAccessible(true);
stateUpdater.set(serverCnx, ServerCnx.State.Failed);
serverCnx.handleNewTxn(any());
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void shouldFailHandleAddPartitionToTxn() throws Exception {
ServerCnx serverCnx = mock(ServerCnx.class, CALLS_REAL_METHODS);
Field stateUpdater = ServerCnx.class.getDeclaredField("state");
stateUpdater.setAccessible(true);
stateUpdater.set(serverCnx, ServerCnx.State.Failed);
serverCnx.handleAddPartitionToTxn(any());
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void shouldFailHandleEndTxn() throws Exception {
ServerCnx serverCnx = mock(ServerCnx.class, CALLS_REAL_METHODS);
Field stateUpdater = ServerCnx.class.getDeclaredField("state");
stateUpdater.setAccessible(true);
stateUpdater.set(serverCnx, ServerCnx.State.Failed);
serverCnx.handleEndTxn(any());
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void shouldFailHandleEndTxnOnPartition() throws Exception {
ServerCnx serverCnx = mock(ServerCnx.class, CALLS_REAL_METHODS);
Field stateUpdater = ServerCnx.class.getDeclaredField("state");
stateUpdater.setAccessible(true);
stateUpdater.set(serverCnx, ServerCnx.State.Failed);
serverCnx.handleEndTxnOnPartition(any());
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void shouldFailHandleEndTxnOnSubscription() throws Exception {
ServerCnx serverCnx = mock(ServerCnx.class, CALLS_REAL_METHODS);
Field stateUpdater = ServerCnx.class.getDeclaredField("state");
stateUpdater.setAccessible(true);
stateUpdater.set(serverCnx, ServerCnx.State.Failed);
serverCnx.handleEndTxnOnSubscription(any());
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void shouldFailHandleAddSubscriptionToTxn() throws Exception {
ServerCnx serverCnx = mock(ServerCnx.class, CALLS_REAL_METHODS);
Field stateUpdater = ServerCnx.class.getDeclaredField("state");
stateUpdater.setAccessible(true);
stateUpdater.set(serverCnx, ServerCnx.State.Failed);
serverCnx.handleAddSubscriptionToTxn(any());
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void shouldFailHandleCommandWatchTopicList() throws Exception {
ServerCnx serverCnx = mock(ServerCnx.class, CALLS_REAL_METHODS);
Field stateUpdater = ServerCnx.class.getDeclaredField("state");
stateUpdater.setAccessible(true);
stateUpdater.set(serverCnx, ServerCnx.State.Failed);
serverCnx.handleCommandWatchTopicList(any());
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void shouldFailHandleCommandWatchTopicListClose() throws Exception {
ServerCnx serverCnx = mock(ServerCnx.class, CALLS_REAL_METHODS);
Field stateUpdater = ServerCnx.class.getDeclaredField("state");
stateUpdater.setAccessible(true);
stateUpdater.set(serverCnx, ServerCnx.State.Failed);
serverCnx.handleCommandWatchTopicListClose(any());
}
}

0 comments on commit 694aa13

Please sign in to comment.