[CELEBORN-1300] Optimize CelebornInputStreamImpl's memory usage
### What changes were proposed in this pull request?
To avoid excessive memory usage when `CelebornShuffleReader` creates input streams, this PR does the following (a simplified sketch follows the list):

1. The constructor of `CelebornInputStream` no longer fetches a chunk.
2. `compressedBuf` and `rawDataBuf` are created the first time `fillBuffer` is called.
3. When `fillBuffer` returns `false`, meaning the input stream is exhausted, `close` is called and resources are released.
4. `CelebornFetchFailureSuite` is only run for Spark 3.0 and newer.
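
As a rough mental model of items 1–3, here is a minimal sketch of the new lifecycle. This is not Celeborn's actual code: `LazyShuffleStream` and its members are hypothetical stand-ins, and the chunk-fetching logic is elided.

```java
import java.io.IOException;
import java.io.InputStream;

// Hypothetical, simplified stream illustrating the lifecycle above.
class LazyShuffleStream extends InputStream {
  private final int bufferSize;
  private byte[] rawDataBuf;      // item 2: allocated lazily, not in the constructor
  private boolean firstFill = true;
  private boolean closed = false;

  LazyShuffleStream(int bufferSize) {
    this.bufferSize = bufferSize; // item 1: the constructor allocates and fetches nothing
  }

  @Override
  public int read() throws IOException {
    if (!fillBuffer()) {
      return -1;                  // end of stream
    }
    return rawDataBuf[0] & 0xff;  // placeholder; a real impl tracks a read position
  }

  private boolean fillBuffer() throws IOException {
    if (firstFill) {
      rawDataBuf = new byte[bufferSize]; // buffers exist only once reading starts
      firstFill = false;
    }
    boolean exhausted = fetchNextChunk() == null;
    if (exhausted) {
      close();                    // item 3: a drained stream releases its own resources
      return false;
    }
    return true;
  }

  private Object fetchNextChunk() {
    return null;                  // placeholder for the real chunk fetch
  }

  @Override
  public synchronized void close() {
    if (!closed) {
      rawDataBuf = null;          // drop references so the GC can reclaim the buffers
      closed = true;
    }
  }
}
```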

### Why are the changes needed?
Same as above: it avoids excessive memory usage when `CelebornShuffleReader` creates a large number of input streams before they are read.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
GA and e2e tests.

Closes apache#2348 from waitinfuture/1300.

Lead-authored-by: zky.zhoukeyong <[email protected]>
Co-authored-by: Keyong Zhou <[email protected]>
Signed-off-by: zky.zhoukeyong <[email protected]>
waitinfuture committed Mar 5, 2024
commit 8b6bc35 (1 parent: 0285021)
Showing 3 changed files with 248 additions and 204 deletions.
```diff
@@ -124,19 +124,20 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
     private final CelebornConf conf;
     private final TransportClientFactory clientFactory;
     private final String shuffleKey;
-    private final PartitionLocation[] locations;
-    private final int[] attempts;
+    private PartitionLocation[] locations;
+    private int[] attempts;
     private final int attemptNumber;
     private final int startMapIndex;
     private final int endMapIndex;
 
-    private final Map<Integer, Set<Integer>> batchesRead = new HashMap<>();
+    private Map<Integer, Set<Integer>> batchesRead = new HashMap<>();
 
     private byte[] compressedBuf;
     private byte[] rawDataBuf;
     private Decompressor decompressor;
 
     private ByteBuf currentChunk;
+    private boolean firstChunk = true;
     private PartitionReader currentReader;
     private final int fetchChunkMaxRetry;
     private int fetchChunkRetryCnt = 0;
@@ -159,14 +160,15 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
     private boolean fetchExcludeWorkerOnFailureEnabled;
     private boolean shuffleCompressionEnabled;
     private long fetchExcludedWorkerExpireTimeout;
-    private final ConcurrentHashMap<String, Long> fetchExcludedWorkers;
+    private ConcurrentHashMap<String, Long> fetchExcludedWorkers;
 
     private boolean containLocalRead = false;
     private ShuffleClient shuffleClient;
     private int appShuffleId;
     private int shuffleId;
     private int partitionId;
     private ExceptionMaker exceptionMaker;
+    private boolean closed = false;
 
     CelebornInputStreamImpl(
         CelebornConf conf,
@@ -203,16 +205,6 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
       this.fetchExcludedWorkerExpireTimeout = conf.clientFetchExcludedWorkerExpireTimeout();
       this.fetchExcludedWorkers = fetchExcludedWorkers;
 
-      int bufferSize = conf.clientFetchBufferSize();
-      if (shuffleCompressionEnabled) {
-        int headerLen = Decompressor.getCompressionHeaderLength(conf);
-        bufferSize += headerLen;
-        compressedBuf = new byte[bufferSize];
-
-        decompressor = Decompressor.getDecompressor(conf);
-      }
-      rawDataBuf = new byte[bufferSize];
-
       if (conf.clientPushReplicateEnabled()) {
         fetchChunkMaxRetry = conf.clientFetchMaxRetriesForEachReplica() * 2;
       } else {
@@ -228,7 +220,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream {
       this.shuffleId = shuffleId;
       this.shuffleClient = shuffleClient;
 
-      moveToNextReader();
+      moveToNextReader(false);
     }
 
     private boolean skipLocation(int startMapIndex, int endMapIndex, PartitionLocation location) {
@@ -270,7 +262,7 @@ private PartitionLocation nextReadableLocation() {
       return currentLocation;
     }
 
-    private void moveToNextReader() throws IOException {
+    private void moveToNextReader(boolean fetchChunk) throws IOException {
       if (currentReader != null) {
         currentReader.close();
         currentReader = null;
@@ -291,7 +283,9 @@ private void moveToNextReader() throws IOException {
         currentReader = createReaderWithRetry(currentLocation);
         fileIndex++;
       }
-      currentChunk = getNextChunk();
+      if (fetchChunk) {
+        currentChunk = getNextChunk();
+      }
     }
 
     private void excludeFailedLocation(PartitionLocation location, Exception e) {
```
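
The `fetchChunk` flag above separates advancing to the next reader from fetching its first chunk. A minimal sketch of the two call sites, using hypothetical simplified types (the real method also handles retries and location skipping):

```java
import java.io.IOException;

// Hypothetical reduction of the control flow around moveToNextReader(boolean).
class DeferredFetchReader {
  private Object currentChunk;     // stands in for the Netty ByteBuf

  DeferredFetchReader() throws IOException {
    moveToNextReader(false);       // constructor: position at the first reader, fetch nothing
  }

  boolean moveToNextChunk() throws IOException {
    moveToNextReader(true);        // mid-read: advancing must also fetch data
    return currentChunk != null;
  }

  private void moveToNextReader(boolean fetchChunk) throws IOException {
    // ... select the next partition reader here ...
    if (fetchChunk) {
      currentChunk = getNextChunk();
    }
  }

  private Object getNextChunk() {
    return new Object();           // placeholder for the real chunk fetch
  }
}
```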
Expand Down Expand Up @@ -517,25 +511,40 @@ public int read(byte[] b, int off, int len) throws IOException {
}

@Override
public void close() {
int locationsCount = locations.length;
logger.debug(
"total location count {} read {} skip {}",
locationsCount,
locationsCount - skipCount.sum(),
skipCount.sum());
if (currentChunk != null) {
logger.debug("Release chunk {}", currentChunk);
currentChunk.release();
currentChunk = null;
}
if (currentReader != null) {
logger.debug("Closing reader");
currentReader.close();
currentReader = null;
}
if (containLocalRead) {
ShuffleClient.printReadStats(logger);
public synchronized void close() {
if (!closed) {
int locationsCount = locations.length;
logger.debug(
"AppShuffleId {}, shuffleId {}, partitionId {}, total location count {}, read {}, skip {}",
appShuffleId,
shuffleId,
partitionId,
locationsCount,
locationsCount - skipCount.sum(),
skipCount.sum());
if (currentChunk != null) {
logger.debug("Release chunk {}", currentChunk);
currentChunk.release();
currentChunk = null;
}
if (currentReader != null) {
logger.debug("Closing reader");
currentReader.close();
currentReader = null;
}
if (containLocalRead) {
ShuffleClient.printReadStats(logger);
}

compressedBuf = null;
rawDataBuf = null;
batchesRead = null;
locations = null;
attempts = null;
decompressor = null;
fetchExcludedWorkers = null;

closed = true;
}
}

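Two details of the new `close()` are worth noting: the `synchronized` + `closed` guard makes it idempotent, since it can now be reached both from the caller and from `fillBuffer` when the stream drains itself; and nulling the buffer, map, and array fields (which is why their `final` modifiers were removed above) lets the GC reclaim them even while the exhausted stream object itself remains referenced. A minimal sketch of the pattern, with a hypothetical field:

```java
// Hypothetical sketch of an idempotent, memory-releasing close.
class ReleasingResource implements AutoCloseable {
  private byte[] buffer = new byte[64 * 1024];
  private boolean closed = false;

  @Override
  public synchronized void close() {
    if (!closed) {        // second and later calls are no-ops
      buffer = null;      // drop the reference so the GC can reclaim it immediately
      closed = true;
    }
  }
}
```
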
```diff
@@ -548,7 +557,7 @@ private boolean moveToNextChunk() throws IOException {
         currentChunk = getNextChunk();
         return true;
       } else if (fileIndex < locations.length) {
-        moveToNextReader();
+        moveToNextReader(true);
         return currentReader != null;
       }
       if (currentReader != null) {
@@ -558,9 +567,27 @@
       return false;
     }
 
+    private void init() {
+      int bufferSize = conf.clientFetchBufferSize();
+
+      if (shuffleCompressionEnabled) {
+        int headerLen = Decompressor.getCompressionHeaderLength(conf);
+        bufferSize += headerLen;
+        compressedBuf = new byte[bufferSize];
+        decompressor = Decompressor.getDecompressor(conf);
+      }
+      rawDataBuf = new byte[bufferSize];
+    }
+
     private boolean fillBuffer() throws IOException {
       try {
+        if (firstChunk && currentReader != null) {
+          init();
+          currentChunk = getNextChunk();
+          firstChunk = false;
+        }
         if (currentChunk == null) {
+          close();
           return false;
         }
```
(Diff truncated; the remaining changed files were not loaded on this page.)