Skip to content

Commit

Permalink
[FLINK-16744][task] Implement channel state reading and writing for u…
Browse files Browse the repository at this point in the history
…naligned checkpoints
  • Loading branch information
rkhachatryan authored and pnowojski committed Apr 10, 2020
1 parent 97a7a10 commit 53d9f36
Show file tree
Hide file tree
Showing 22 changed files with 2,954 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.runtime.checkpoint.channel;

import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.state.AbstractChannelStateHandle;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.CheckpointStreamFactory.CheckpointStateOutputStream;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.RunnableWithException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.concurrent.NotThreadSafe;

import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;

import static org.apache.flink.runtime.state.CheckpointedStateScope.EXCLUSIVE;
import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.apache.flink.util.Preconditions.checkState;

/**
* Writes channel state for a specific checkpoint-subtask-attempt triple.
*/
@NotThreadSafe
class ChannelStateCheckpointWriter {
private static final Logger LOG = LoggerFactory.getLogger(ChannelStateCheckpointWriter.class);

private final DataOutputStream dataStream;
private final CheckpointStateOutputStream checkpointStream;
private final ChannelStateWriteResult result;
private final Map<InputChannelInfo, List<Long>> inputChannelOffsets = new HashMap<>();
private final Map<ResultSubpartitionInfo, List<Long>> resultSubpartitionOffsets = new HashMap<>();
private final ChannelStateSerializer serializer;
private final long checkpointId;
private boolean allInputsReceived = false;
private boolean allOutputsReceived = false;
private final RunnableWithException onComplete;

ChannelStateCheckpointWriter(
CheckpointStartRequest startCheckpointItem,
CheckpointStreamFactory streamFactory,
ChannelStateSerializer serializer,
RunnableWithException onComplete) throws Exception {
this(
startCheckpointItem.getCheckpointId(),
startCheckpointItem.getTargetResult(),
streamFactory.createCheckpointStateOutputStream(EXCLUSIVE),
serializer,
onComplete);
}

ChannelStateCheckpointWriter(
long checkpointId,
ChannelStateWriteResult result,
CheckpointStateOutputStream stream,
ChannelStateSerializer serializer,
RunnableWithException onComplete) throws Exception {
this(checkpointId, result, serializer, onComplete, stream, new DataOutputStream(stream));
}

ChannelStateCheckpointWriter(
long checkpointId,
ChannelStateWriteResult result,
ChannelStateSerializer serializer,
RunnableWithException onComplete,
CheckpointStateOutputStream checkpointStateOutputStream,
DataOutputStream dataStream) throws Exception {
this.checkpointId = checkpointId;
this.result = checkNotNull(result);
this.checkpointStream = checkNotNull(checkpointStateOutputStream);
this.serializer = checkNotNull(serializer);
this.dataStream = checkNotNull(dataStream);
this.onComplete = checkNotNull(onComplete);
runWithChecks(() -> serializer.writeHeader(dataStream));
}

void writeInput(InputChannelInfo info, Buffer... flinkBuffers) throws Exception {
write(inputChannelOffsets, info, flinkBuffers, !allInputsReceived);
}

void writeOutput(ResultSubpartitionInfo info, Buffer... flinkBuffers) throws Exception {
write(resultSubpartitionOffsets, info, flinkBuffers, !allOutputsReceived);
}

private <K> void write(Map<K, List<Long>> offsets, K key, Buffer[] flinkBuffers, boolean precondition) throws Exception {
try {
if (result.isDone()) {
return;
}
runWithChecks(() -> {
checkState(precondition);
offsets
.computeIfAbsent(key, unused -> new ArrayList<>())
.add(checkpointStream.getPos());
serializer.writeData(dataStream, flinkBuffers);
});
} finally {
for (Buffer flinkBuffer : flinkBuffers) {
flinkBuffer.recycleBuffer();
}
}
}

void completeInput() throws Exception {
LOG.debug("complete input, output completed: {}", allOutputsReceived);
complete(!allInputsReceived, () -> allInputsReceived = true);
}

void completeOutput() throws Exception {
LOG.debug("complete output, input completed: {}", allInputsReceived);
complete(!allOutputsReceived, () -> allOutputsReceived = true);
}

private void complete(boolean precondition, RunnableWithException complete) throws Exception {
if (result.isDone()) {
// likely after abort - only need to set the flag run onComplete callback
doComplete(precondition, complete, onComplete);
} else {
runWithChecks(() -> doComplete(precondition, complete, onComplete, this::finishWriteAndResult));
}
}

private void finishWriteAndResult() throws IOException {
dataStream.flush();
StreamStateHandle underlying = checkpointStream.closeAndGetHandle();
complete(
result.inputChannelStateHandles,
inputChannelOffsets,
(chan, offsets) -> new InputChannelStateHandle(chan, underlying, offsets));
complete(
result.resultSubpartitionStateHandles,
resultSubpartitionOffsets,
(chan, offsets) -> new ResultSubpartitionStateHandle(chan, underlying, offsets));
}

private void doComplete(boolean precondition, RunnableWithException complete, RunnableWithException... callbacks) throws Exception {
Preconditions.checkArgument(precondition);
complete.run();
if (allInputsReceived && allOutputsReceived) {
for (RunnableWithException callback : callbacks) {
callback.run();
}
}
}

private <I, H extends AbstractChannelStateHandle<I>> void complete(
CompletableFuture<Collection<H>> future,
Map<I, List<Long>> offsets,
BiFunction<I, List<Long>, H> buildHandle) {
final Collection<H> handles = new ArrayList<>();
for (Map.Entry<I, List<Long>> e : offsets.entrySet()) {
handles.add(buildHandle.apply(e.getKey(), e.getValue()));
}
future.complete(handles);
LOG.debug("channel state write completed, checkpointId: {}, handles: {}", checkpointId, handles);
}

private void runWithChecks(RunnableWithException r) throws Exception {
try {
checkState(!result.isDone(), "result is already completed", result);
r.run();
} catch (Exception e) {
fail(e);
throw e;
}
}

public void fail(Throwable e) throws Exception {
result.fail(e);
checkpointStream.close();
dataStream.close();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.runtime.checkpoint.channel;

import org.apache.flink.annotation.Internal;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.checkpoint.channel.RefCountingFSDataInputStream.RefCountingFSDataInputStreamFactory;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.AbstractChannelStateHandle;
import org.apache.flink.util.Preconditions;

import org.apache.flink.shaded.guava18.com.google.common.io.Closer;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.concurrent.NotThreadSafe;

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

import static java.util.Arrays.asList;
import static org.apache.flink.util.Preconditions.checkState;

/**
* {@link ChannelStateReader} implementation. Usage considerations:
* <ol>
* <li>state of a channel can be read once per instance of this class; once done it returns
* {@link org.apache.flink.runtime.checkpoint.channel.ChannelStateReader.ReadResult#NO_MORE_DATA NO_MORE_DATA}</li>
* <li>reader/writer indices of the passed buffer are respected and updated</li>
* <li>buffers must be prepared (cleared) before passing to reader</li>
* <li>buffers must be released after use</li>
* </ol>
*/
@NotThreadSafe
@Internal
public class ChannelStateReaderImpl implements ChannelStateReader {
private static final Logger log = LoggerFactory.getLogger(ChannelStateReaderImpl.class);

private final Map<InputChannelInfo, ChannelStateStreamReader> inputChannelHandleReaders;
private final Map<ResultSubpartitionInfo, ChannelStateStreamReader> resultSubpartitionHandleReaders;

public ChannelStateReaderImpl(TaskStateSnapshot snapshot) {
this(snapshot, new ChannelStateSerializerImpl());
}

ChannelStateReaderImpl(TaskStateSnapshot snapshot, ChannelStateDeserializer serializer) {
RefCountingFSDataInputStreamFactory streamFactory = new RefCountingFSDataInputStreamFactory(serializer);
final HashMap<InputChannelInfo, ChannelStateStreamReader> inputChannelHandleReadersTmp = new HashMap<>();
final HashMap<ResultSubpartitionInfo, ChannelStateStreamReader> resultSubpartitionHandleReadersTmp = new HashMap<>();
for (Map.Entry<OperatorID, OperatorSubtaskState> e : snapshot.getSubtaskStateMappings()) {
addReaders(inputChannelHandleReadersTmp, e.getValue().getInputChannelState(), streamFactory);
addReaders(resultSubpartitionHandleReadersTmp, e.getValue().getResultSubpartitionState(), streamFactory);
}
inputChannelHandleReaders = inputChannelHandleReadersTmp; // memory barrier to allow another thread call clear()
resultSubpartitionHandleReaders = resultSubpartitionHandleReadersTmp; // memory barrier to allow another thread call clear()
}

private <T> void addReaders(
Map<T, ChannelStateStreamReader> readerMap,
Collection<? extends AbstractChannelStateHandle<T>> handles,
RefCountingFSDataInputStreamFactory streamFactory) {
for (AbstractChannelStateHandle<T> handle : handles) {
checkState(!readerMap.containsKey(handle.getInfo()), "multiple states exist for channel: " + handle.getInfo());
readerMap.put(handle.getInfo(), new ChannelStateStreamReader(handle, streamFactory));
}
}

@Override
public ReadResult readInputData(InputChannelInfo info, Buffer buffer) throws IOException {
log.debug("readInputData, resultSubpartitionInfo: {} , buffer {}", info, buffer);
return getReader(info, inputChannelHandleReaders).readInto(buffer);
}

@Override
public ReadResult readOutputData(ResultSubpartitionInfo info, BufferBuilder bufferBuilder) throws IOException {
log.debug("readOutputData, resultSubpartitionInfo: {} , bufferBuilder {}", info, bufferBuilder);
return getReader(info, resultSubpartitionHandleReaders).readInto(bufferBuilder);
}

private <K> ChannelStateStreamReader getReader(K info, Map<K, ChannelStateStreamReader> readerMap) {
Preconditions.checkArgument(readerMap.containsKey(info), String.format("unknown channel %s. Known channels: %s", info, readerMap.keySet()));
return readerMap.get(info);
}

@Override
public void close() throws Exception {
try (Closer closer = Closer.create()) {
for (Map<?, ChannelStateStreamReader> map : asList(inputChannelHandleReaders, resultSubpartitionHandleReaders)) {
for (ChannelStateStreamReader reader : map.values()) {
closer.register(reader);
}
map.clear();
}
}
}

}
Loading

0 comments on commit 53d9f36

Please sign in to comment.