Skip to content

Commit

Permalink
[FLINK-17928][checkpointing] Fix ChannelStateHandle size
Browse files Browse the repository at this point in the history
Store state size explicitly because underlying state
handle may be shared.
  • Loading branch information
rkhachatryan authored and pnowojski committed May 28, 2020
1 parent 1b824d5 commit f89c606
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.state.AbstractChannelStateHandle;
import org.apache.flink.runtime.state.AbstractChannelStateHandle.StateContentMetaInfo;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.CheckpointStreamFactory.CheckpointStateOutputStream;
import org.apache.flink.runtime.state.InputChannelStateHandle;
Expand Down Expand Up @@ -61,8 +62,8 @@ class ChannelStateCheckpointWriter {
private final DataOutputStream dataStream;
private final CheckpointStateOutputStream checkpointStream;
private final ChannelStateWriteResult result;
private final Map<InputChannelInfo, List<Long>> inputChannelOffsets = new HashMap<>();
private final Map<ResultSubpartitionInfo, List<Long>> resultSubpartitionOffsets = new HashMap<>();
private final Map<InputChannelInfo, StateContentMetaInfo> inputChannelOffsets = new HashMap<>();
private final Map<ResultSubpartitionInfo, StateContentMetaInfo> resultSubpartitionOffsets = new HashMap<>();
private final ChannelStateSerializer serializer;
private final long checkpointId;
private boolean allInputsReceived = false;
Expand Down Expand Up @@ -115,17 +116,19 @@ void writeOutput(ResultSubpartitionInfo info, Buffer... flinkBuffers) throws Exc
write(resultSubpartitionOffsets, info, flinkBuffers, !allOutputsReceived);
}

private <K> void write(Map<K, List<Long>> offsets, K key, Buffer[] flinkBuffers, boolean precondition) throws Exception {
private <K> void write(Map<K, StateContentMetaInfo> offsets, K key, Buffer[] flinkBuffers, boolean precondition) throws Exception {
try {
if (result.isDone()) {
return;
}
runWithChecks(() -> {
checkState(precondition);
offsets
.computeIfAbsent(key, unused -> new ArrayList<>())
.add(checkpointStream.getPos());
long offset = checkpointStream.getPos();
serializer.writeData(dataStream, flinkBuffers);
long size = checkpointStream.getPos() - offset;
offsets
.computeIfAbsent(key, unused -> new StateContentMetaInfo())
.withDataAdded(offset, size);
});
} finally {
for (Buffer flinkBuffer : flinkBuffers) {
Expand Down Expand Up @@ -179,10 +182,10 @@ private void doComplete(boolean precondition, RunnableWithException complete, Ru
private <I, H extends AbstractChannelStateHandle<I>> void complete(
StreamStateHandle underlying,
CompletableFuture<Collection<H>> future,
Map<I, List<Long>> offsets,
Map<I, StateContentMetaInfo> offsets,
HandleFactory<I, H> handleFactory) throws IOException {
final Collection<H> handles = new ArrayList<>();
for (Map.Entry<I, List<Long>> e : offsets.entrySet()) {
for (Map.Entry<I, StateContentMetaInfo> e : offsets.entrySet()) {
handles.add(createHandle(handleFactory, underlying, e.getKey(), e.getValue()));
}
future.complete(handles);
Expand All @@ -193,15 +196,19 @@ private <I, H extends AbstractChannelStateHandle<I>> H createHandle(
HandleFactory<I, H> handleFactory,
StreamStateHandle underlying,
I channelInfo,
List<Long> offsets) throws IOException {
StateContentMetaInfo contentMetaInfo) throws IOException {
Optional<byte[]> bytes = underlying.asBytesIfInMemory(); // todo: consider restructuring channel state and removing this method: https://issues.apache.org/jira/browse/FLINK-17972
if (bytes.isPresent()) {
StreamStateHandle extracted = new ByteStreamStateHandle(
randomUUID().toString(),
serializer.extractAndMerge(bytes.get(), contentMetaInfo.getOffsets()));
return handleFactory.create(
channelInfo,
new ByteStreamStateHandle(randomUUID().toString(), serializer.extractAndMerge(bytes.get(), offsets)),
singletonList(serializer.getHeaderLength()));
extracted,
singletonList(serializer.getHeaderLength()),
extracted.getStateSize());
} else {
return handleFactory.create(channelInfo, underlying, offsets);
return handleFactory.create(channelInfo, underlying, contentMetaInfo.getOffsets(), contentMetaInfo.getSize());
}
}

Expand All @@ -221,7 +228,7 @@ public void fail(Throwable e) throws Exception {
}

private interface HandleFactory<I, H extends AbstractChannelStateHandle<I>> {
H create(I info, StreamStateHandle underlying, List<Long> offsets);
H create(I info, StreamStateHandle underlying, List<Long> offsets, long size);

HandleFactory<InputChannelInfo, InputChannelStateHandle> INPUT_CHANNEL = InputChannelStateHandle::new;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
import org.apache.flink.runtime.state.AbstractChannelStateHandle;
import org.apache.flink.runtime.state.AbstractChannelStateHandle.StateContentMetaInfo;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
Expand Down Expand Up @@ -51,7 +52,7 @@ ResultSubpartitionStateHandle deserializeResultSubpartitionStateHandle(

return deserializeChannelStateHandle(
is -> new ResultSubpartitionInfo(is.readInt(), is.readInt()),
(streamStateHandle, longs, info) -> new ResultSubpartitionStateHandle(info, streamStateHandle, longs),
(streamStateHandle, contentMetaInfo, info) -> new ResultSubpartitionStateHandle(info, streamStateHandle, contentMetaInfo),
dis,
context);
}
Expand All @@ -69,7 +70,7 @@ InputChannelStateHandle deserializeInputChannelStateHandle(

return deserializeChannelStateHandle(
is -> new InputChannelInfo(is.readInt(), is.readInt()),
(streamStateHandle, longs, inputChannelInfo) -> new InputChannelStateHandle(inputChannelInfo, streamStateHandle, longs),
(streamStateHandle, contentMetaInfo, inputChannelInfo) -> new InputChannelStateHandle(inputChannelInfo, streamStateHandle, contentMetaInfo),
dis,
context);
}
Expand All @@ -83,12 +84,13 @@ private static <I> void serializeChannelStateHandle(
for (long offset : handle.getOffsets()) {
dos.writeLong(offset);
}
dos.writeLong(handle.getStateSize());
serializeStreamStateHandle(handle.getDelegate(), dos);
}

private static <Info, Handle extends AbstractChannelStateHandle<Info>> Handle deserializeChannelStateHandle(
FunctionWithException<DataInputStream, Info, IOException> infoReader,
TriFunctionWithException<StreamStateHandle, List<Long>, Info, Handle, IOException> handleBuilder,
TriFunctionWithException<StreamStateHandle, StateContentMetaInfo, Info, Handle, IOException> handleBuilder,
DataInputStream dis,
MetadataV2V3SerializerBase.DeserializationContext context) throws IOException {

Expand All @@ -98,6 +100,7 @@ private static <Info, Handle extends AbstractChannelStateHandle<Info>> Handle de
for (int i = 0; i < offsetsSize; i++) {
offsets.add(dis.readLong());
}
return handleBuilder.apply(deserializeStreamStateHandle(dis, context), offsets, info);
final long size = dis.readLong();
return handleBuilder.apply(deserializeStreamStateHandle(dis, context), new StateContentMetaInfo(offsets, size), info);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import org.apache.flink.annotation.Internal;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

Expand All @@ -39,11 +41,13 @@ public abstract class AbstractChannelStateHandle<Info> implements StateObject {
* Start offsets in a {@link org.apache.flink.core.fs.FSDataInputStream stream} {@link StreamStateHandle#openInputStream obtained} from {@link #delegate}.
*/
private final List<Long> offsets;
private final long size;

AbstractChannelStateHandle(StreamStateHandle delegate, List<Long> offsets, Info info) {
AbstractChannelStateHandle(StreamStateHandle delegate, List<Long> offsets, Info info, long size) {
this.info = checkNotNull(info);
this.delegate = checkNotNull(delegate);
this.offsets = checkNotNull(offsets);
this.size = size;
}

@Override
Expand All @@ -53,7 +57,7 @@ public void discardState() throws Exception {

@Override
public long getStateSize() {
return delegate.getStateSize();
return size; // can not rely on delegate.getStateSize because it can be shared
}

public List<Long> getOffsets() {
Expand Down Expand Up @@ -84,4 +88,35 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hash(info, delegate, offsets);
}

/**
* Describes the underlying content.
*/
public static class StateContentMetaInfo {
private final List<Long> offsets;
private long size = 0;

public StateContentMetaInfo() {
this(new ArrayList<>(), 0);
}

public StateContentMetaInfo(List<Long> offsets, long size) {
this.offsets = offsets;
this.size = size;
}

public void withDataAdded(long offset, long size) {
this.offsets.add(offset);
this.size += size;
}

public List<Long> getOffsets() {
return Collections.unmodifiableList(offsets);
}

public long getSize() {
return size;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,15 @@ public class InputChannelStateHandle extends AbstractChannelStateHandle<InputCha

private static final long serialVersionUID = 1L;

public InputChannelStateHandle(InputChannelInfo info, StreamStateHandle delegate, StateContentMetaInfo contentMetaInfo) {
this(info, delegate, contentMetaInfo.getOffsets(), contentMetaInfo.getSize());
}

public InputChannelStateHandle(InputChannelInfo info, StreamStateHandle delegate, List<Long> offset) {
super(delegate, offset, info);
this(info, delegate, offset, delegate.getStateSize());
}

public InputChannelStateHandle(InputChannelInfo info, StreamStateHandle delegate, List<Long> offset, long size) {
super(delegate, offset, info, size);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,15 @@ public class ResultSubpartitionStateHandle extends AbstractChannelStateHandle<Re

private static final long serialVersionUID = 1L;

public ResultSubpartitionStateHandle(ResultSubpartitionInfo info, StreamStateHandle delegate, StateContentMetaInfo contentMetaInfo) {
this(info, delegate, contentMetaInfo.getOffsets(), contentMetaInfo.getSize());
}

public ResultSubpartitionStateHandle(ResultSubpartitionInfo info, StreamStateHandle delegate, List<Long> offset) {
super(delegate, offset, info);
this(info, delegate, offset, delegate.getStateSize());
}

public ResultSubpartitionStateHandle(ResultSubpartitionInfo info, StreamStateHandle delegate, List<Long> offset, long size) {
super(delegate, offset, info, size);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.stream.IntStream;

import static java.util.Collections.singletonList;
import static org.apache.flink.core.fs.Path.fromLocalFile;
Expand All @@ -63,6 +64,35 @@ public class ChannelStateCheckpointWriterTest {
@Rule
public final TemporaryFolder temporaryFolder = new TemporaryFolder();

@Test
public void testFileHandleSize() throws Exception {
int numChannels = 3;
int numWritesPerChannel = 4;
int numBytesPerWrite = 5;
ChannelStateWriteResult result = new ChannelStateWriteResult();
ChannelStateCheckpointWriter writer = createWriter(
result,
new FsCheckpointStreamFactory(
getSharedInstance(),
fromLocalFile(temporaryFolder.newFolder("checkpointsDir")),
fromLocalFile(temporaryFolder.newFolder("sharedStateDir")),
numBytesPerWrite - 1,
numBytesPerWrite - 1).createCheckpointStateOutputStream(EXCLUSIVE));

InputChannelInfo[] channels = IntStream.range(0, numChannels).mapToObj(i -> new InputChannelInfo(0, i)).toArray(InputChannelInfo[]::new);
for (int call = 0; call < numWritesPerChannel; call++) {
for (int channel = 0; channel < numChannels; channel++) {
write(writer, channels[channel], getData(numBytesPerWrite));
}
}
writer.completeInput();
writer.completeOutput();

for (InputChannelStateHandle handle : result.inputChannelStateHandles.get()) {
assertEquals((Integer.BYTES + numBytesPerWrite) * numWritesPerChannel, handle.getStateSize());
}
}

@Test
@SuppressWarnings("ConstantConditions")
public void testSmallFilesNotWritten() throws Exception {
Expand Down

0 comments on commit f89c606

Please sign in to comment.