Skip to content

Commit

Permalink
[FLINK-4307] [streaming API] Restore ListState behavior for user-faci…
Browse files Browse the repository at this point in the history
…ng ListStates
  • Loading branch information
StephanEwen committed Aug 2, 2016
1 parent 31837a7 commit d5a06b4
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ public <T> ListState<T> getListState(ListStateDescriptor<T> stateProperties) {
requireNonNull(stateProperties, "The state properties must not be null");
try {
stateProperties.initializeSerializerUnlessSet(getExecutionConfig());
return operator.getPartitionedState(stateProperties);
ListState<T> originalState = operator.getPartitionedState(stateProperties);
return new UserFacingListState<T>(originalState);
} catch (Exception e) {
throw new RuntimeException("Error while getting state", e);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.streaming.api.operators;

import org.apache.flink.api.common.state.ListState;

import java.util.Collections;

/**
* Simple wrapper list state that exposes empty state properly as an empty list.
*
* @param <T> The type of elements in the list state.
*/
class UserFacingListState<T> implements ListState<T> {

private final ListState<T> originalState;

private final Iterable<T> emptyState = Collections.emptyList();

UserFacingListState(ListState<T> originalState) {
this.originalState = originalState;
}

// ------------------------------------------------------------------------

@Override
public Iterable<T> get() throws Exception {
Iterable<T> original = originalState.get();
return original != null ? original : emptyState;
}

@Override
public void add(T value) throws Exception {
originalState.add(value);
}

@Override
public void clear() {
originalState.clear();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,19 @@
import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.api.common.typeutils.base.VoidSerializer;
import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer;
import org.apache.flink.core.fs.Path;
import org.apache.flink.runtime.execution.Environment;

import org.apache.flink.runtime.state.memory.MemListState;
import org.junit.Test;

import org.mockito.invocation.InvocationOnMock;
Expand All @@ -54,7 +58,7 @@ public void testValueStateInstantiation() throws Exception {
final AtomicReference<Object> descriptorCapture = new AtomicReference<>();

StreamingRuntimeContext context = new StreamingRuntimeContext(
createMockOp(descriptorCapture, config),
createDescriptorCapturingMockOp(descriptorCapture, config),
createMockEnvironment(),
Collections.<String, Accumulator<?, ?>>emptyMap());

Expand All @@ -78,7 +82,7 @@ public void testReduceingStateInstantiation() throws Exception {
final AtomicReference<Object> descriptorCapture = new AtomicReference<>();

StreamingRuntimeContext context = new StreamingRuntimeContext(
createMockOp(descriptorCapture, config),
createDescriptorCapturingMockOp(descriptorCapture, config),
createMockEnvironment(),
Collections.<String, Accumulator<?, ?>>emptyMap());

Expand Down Expand Up @@ -107,7 +111,7 @@ public void testListStateInstantiation() throws Exception {
final AtomicReference<Object> descriptorCapture = new AtomicReference<>();

StreamingRuntimeContext context = new StreamingRuntimeContext(
createMockOp(descriptorCapture, config),
createDescriptorCapturingMockOp(descriptorCapture, config),
createMockEnvironment(),
Collections.<String, Accumulator<?, ?>>emptyMap());

Expand All @@ -121,13 +125,29 @@ public void testListStateInstantiation() throws Exception {
assertTrue(serializer instanceof KryoSerializer);
assertTrue(((KryoSerializer<?>) serializer).getKryo().getRegistration(Path.class).getId() > 0);
}

@Test
public void testListStateReturnsEmptyListByDefault() throws Exception {

StreamingRuntimeContext context = new StreamingRuntimeContext(
createPlainMockOp(),
createMockEnvironment(),
Collections.<String, Accumulator<?, ?>>emptyMap());

ListStateDescriptor<String> descr = new ListStateDescriptor<>("name", String.class);
ListState<String> state = context.getListState(descr);

Iterable<String> value = state.get();
assertNotNull(value);
assertFalse(value.iterator().hasNext());
}

// ------------------------------------------------------------------------
//
// ------------------------------------------------------------------------

@SuppressWarnings("unchecked")
private static AbstractStreamOperator<?> createMockOp(
private static AbstractStreamOperator<?> createDescriptorCapturingMockOp(
final AtomicReference<Object> ref, final ExecutionConfig config) throws Exception {

AbstractStreamOperator<?> operatorMock = mock(AbstractStreamOperator.class);
Expand All @@ -145,6 +165,27 @@ public Object answer(InvocationOnMock invocationOnMock) throws Throwable {

return operatorMock;
}

@SuppressWarnings("unchecked")
private static AbstractStreamOperator<?> createPlainMockOp() throws Exception {

AbstractStreamOperator<?> operatorMock = mock(AbstractStreamOperator.class);
when(operatorMock.getExecutionConfig()).thenReturn(new ExecutionConfig());

when(operatorMock.getPartitionedState(any(ListStateDescriptor.class))).thenAnswer(
new Answer<ListState<String>>() {

@Override
public ListState<String> answer(InvocationOnMock invocationOnMock) throws Throwable {
ListStateDescriptor<String> descr =
(ListStateDescriptor<String>) invocationOnMock.getArguments()[0];
return new MemListState<String, Void, String>(
StringSerializer.INSTANCE, VoidSerializer.INSTANCE, descr);
}
});

return operatorMock;
}

private static Environment createMockEnvironment() {
Environment env = mock(Environment.class);
Expand Down

0 comments on commit d5a06b4

Please sign in to comment.