Merge pull request #70 from xuechendi/wip_memory
[Scala] Optimize memory management and tracking
xuechendi authored May 9, 2020
2 parents 26c766a + 8030c0b commit f84a560
Showing 9 changed files with 134 additions and 24 deletions.
ShuffleBuffer.java
@@ -4,6 +4,7 @@
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.PooledByteBufAllocator;
 import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.storage.pmof.NettyByteBufferPool;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.memory.UnsafeMemoryAllocator;
 import sun.nio.ch.FileChannelImpl;
@@ -54,7 +55,7 @@ public ShuffleBuffer(long length, EqService service, boolean supportNettyBuffer)
       this.byteBuffer = convertToByteBuffer();
       this.byteBuffer.limit((int)length);
     } else {
-      this.buf = PooledByteBufAllocator.DEFAULT.directBuffer((int) this.length, (int)this.length);
+      this.buf = NettyByteBufferPool.allocateNewBuffer((int) this.length);
       this.address = this.buf.memoryAddress();
       this.byteBuffer = this.buf.nioBuffer(0, (int)length);
     }
@@ -135,7 +136,7 @@ public ManagedBuffer close() {
       }
     } else {
       if (this.supportNettyBuffer) {
-        this.buf.release();
+        NettyByteBufferPool.releaseBuffer(this.buf);
       } else {
         unsafeAlloc.free(memoryBlock);
       }
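Both the allocation and the release of the RDMA shuffle buffer now go through the new NettyByteBufferPool (full listing below), so every ShuffleBuffer is counted in the pool's refCnt/allocatedBytes statistics; presumably the switch away from PooledByteBufAllocator also keeps direct-memory usage proportional to live buffers rather than retained by Netty's arenas.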
13 changes: 9 additions & 4 deletions core/src/main/java/org/apache/spark/storage/pmof/PmemBuffer.java
@@ -15,12 +15,15 @@ public class PmemBuffer {
   private native long nativeDeletePmemBuffer(long pmBuffer);

   private boolean closed = false;
+  private long len = 0;
   long pmBuffer;

   PmemBuffer() {
     pmBuffer = nativeNewPmemBuffer();
   }

   PmemBuffer(long len) {
+    this.len = len;
+    NettyByteBufferPool.unpooledInc(len);
     pmBuffer = nativeNewPmemBufferBySize(len);
   }
@@ -48,6 +51,7 @@ void put(byte[] bytes, int off, int len) {
   }

   void clean() {
+    NettyByteBufferPool.unpooledDec(len);
     nativeCleanPmemBuffer(pmBuffer);
   }
@@ -60,9 +64,10 @@ long getDirectAddr() {
   }

   synchronized void close() {
-    if (!closed) {
-      nativeDeletePmemBuffer(pmBuffer);
-      closed = true;
-    }
+    if (!closed) {
+      clean();
+      nativeDeletePmemBuffer(pmBuffer);
+      closed = true;
+    }
   }
 }
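Taken together, these changes tie the native ("unpooled") byte counter to the PmemBuffer lifetime: the sized constructor records the bytes, and close() now runs clean() so the counter is decremented before the native buffer is deleted. A minimal lifecycle sketch — illustrative only, since PmemBuffer is package-private and JNI-backed (the pmof native library must be loadable), and PmemBufferLifecycle is a hypothetical name:

    package org.apache.spark.storage.pmof

    object PmemBufferLifecycle {
      def main(args: Array[String]): Unit = {
        val buf = new PmemBuffer(4096L) // constructor records +4096 via NettyByteBufferPool.unpooledInc
        try {
          // ... fill and drain the buffer ...
        } finally {
          // close() now calls clean() first, so unpooledDec(4096) balances the
          // constructor's unpooledInc before nativeDeletePmemBuffer runs.
          buf.close()
        }
        // The "Native" column of the pool counters is back to its prior value.
        println(NettyByteBufferPool)
      }
    }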
NettyByteBufferPool.scala (new file)
@@ -0,0 +1,79 @@
+package org.apache.spark.storage.pmof
+
+import java.util.concurrent.atomic.AtomicLong
+import io.netty.buffer.{ByteBuf, PooledByteBufAllocator, UnpooledByteBufAllocator}
+import scala.collection.mutable.Stack
+import java.lang.RuntimeException
+import org.apache.spark.internal.Logging
+
+object NettyByteBufferPool extends Logging {
+  private val allocatedBufRenCnt: AtomicLong = new AtomicLong(0)
+  private val allocatedBytes: AtomicLong = new AtomicLong(0)
+  private val peakAllocatedBytes: AtomicLong = new AtomicLong(0)
+  private val unpooledAllocatedBytes: AtomicLong = new AtomicLong(0)
+  private var fixedBufferSize: Long = 0
+  private val allocatedBufferPool: Stack[ByteBuf] = Stack[ByteBuf]()
+  private var reachRead = false
+  private val allocator = UnpooledByteBufAllocator.DEFAULT
+
+  def allocateNewBuffer(bufSize: Int): ByteBuf = synchronized {
+    if (fixedBufferSize == 0) {
+      fixedBufferSize = bufSize
+    } else if (bufSize > fixedBufferSize) {
+      throw new RuntimeException(s"allocateNewBuffer, expected size is ${fixedBufferSize}, actual size is ${bufSize}")
+    }
+    allocatedBufRenCnt.getAndIncrement()
+    allocatedBytes.getAndAdd(bufSize)
+    if (allocatedBytes.get > peakAllocatedBytes.get) {
+      peakAllocatedBytes.set(allocatedBytes.get)
+    }
+    try {
+      /*if (allocatedBufferPool.isEmpty == false) {
+        allocatedBufferPool.pop
+      } else {
+        allocator.directBuffer(bufSize, bufSize)
+      }*/
+      allocator.directBuffer(bufSize, bufSize)
+    } catch {
+      case e: Throwable =>
+        logError(s"allocateNewBuffer size is ${bufSize}")
+        throw e
+    }
+  }
+
+  def releaseBuffer(buf: ByteBuf): Unit = synchronized {
+    allocatedBufRenCnt.getAndDecrement()
+    allocatedBytes.getAndAdd(0 - fixedBufferSize)
+    buf.clear()
+    //allocatedBufferPool.push(buf)
+    buf.release(buf.refCnt())
+  }
+
+  def unpooledInc(bufSize: Int): Unit = synchronized {
+    if (reachRead == false) {
+      reachRead = true
+      peakAllocatedBytes.set(0)
+    }
+    unpooledAllocatedBytes.getAndAdd(bufSize)
+  }
+
+  def unpooledDec(bufSize: Int): Unit = synchronized {
+    unpooledAllocatedBytes.getAndAdd(0 - bufSize)
+  }
+
+  def unpooledInc(bufSize: Long): Unit = synchronized {
+    if (reachRead == false) {
+      reachRead = true
+      peakAllocatedBytes.set(0)
+    }
+    unpooledAllocatedBytes.getAndAdd(bufSize)
+  }
+
+  def unpooledDec(bufSize: Long): Unit = synchronized {
+    unpooledAllocatedBytes.getAndAdd(0 - bufSize)
+  }
+
+  override def toString(): String = synchronized {
+    return s"NettyBufferPool [refCnt|allocatedBytes|Peak|Native] is [${allocatedBufRenCnt.get}|${allocatedBytes.get}|${peakAllocatedBytes.get}|${unpooledAllocatedBytes.get}]"
+  }
+}
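A minimal usage sketch of the new pool (PoolExample is a hypothetical driver, and the printed line assumes a fresh JVM with no other pool users): the first allocateNewBuffer call fixes the pool's buffer size, any larger request throws, and releaseBuffer always subtracts that fixed size and drops the buffer's entire refCnt:

    import io.netty.buffer.ByteBuf
    import org.apache.spark.storage.pmof.NettyByteBufferPool

    object PoolExample {
      def main(args: Array[String]): Unit = {
        val buf: ByteBuf = NettyByteBufferPool.allocateNewBuffer(4096) // fixes the pool size at 4096
        try {
          buf.writeBytes(Array.fill[Byte](16)(1.toByte)) // use the buffer
        } finally {
          NettyByteBufferPool.releaseBuffer(buf) // refCnt and allocatedBytes go back down
        }
        // With no other pool users this prints:
        // NettyBufferPool [refCnt|allocatedBytes|Peak|Native] is [0|0|4096|0]
        println(NettyByteBufferPool)
      }
    }

Note that releaseBuffer subtracts fixedBufferSize rather than the released buffer's actual size, which is why the pool enforces a single fixed size in the first place.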
PmemBlockInputStream.scala
@@ -11,7 +11,8 @@ class PmemBlockInputStream[K, C](pmemBlockOutputStream: PmemBlockOutputStream, s
   val serInstance: SerializerInstance = serializer.newInstance()
   val persistentMemoryWriter: PersistentMemoryHandler = PersistentMemoryHandler.getPersistentMemoryHandler
   var pmemInputStream: PmemInputStream = new PmemInputStream(persistentMemoryWriter, blockId.name)
-  var inObjStream: DeserializationStream = serInstance.deserializeStream(pmemInputStream)
+  val wrappedStream = serializerManager.wrapStream(blockId, pmemInputStream)
+  var inObjStream: DeserializationStream = serInstance.deserializeStream(wrappedStream)

   var total_records: Long = 0
   var indexInBatch: Int = 0
@@ -45,6 +46,7 @@
   }

   def close(): Unit = {
+    inObjStream.close
     pmemInputStream.close
     inObjStream = null
   }
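serializerManager.wrapStream applies Spark's configured stream compression and encryption, so this read side must mirror the wrapping added on the write side in PmemBlockOutputStream below, or deserialization would be handed compressed bytes. A self-contained illustration of that symmetry (SymmetricWrap is a hypothetical name, and GZIP stands in for the SerializerManager codec purely for illustration):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
    import java.util.zip.{GZIPInputStream, GZIPOutputStream}

    object SymmetricWrap {
      def main(args: Array[String]): Unit = {
        val sink = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(new GZIPOutputStream(sink)) // write side wraps
        out.writeObject("payload")
        out.close()

        // The read side must unwrap with the same codec, as this diff now does:
        val in = new ObjectInputStream(
          new GZIPInputStream(new ByteArrayInputStream(sink.toByteArray)))
        println(in.readObject()) // prints "payload"; skipping GZIPInputStream would fail
        in.close()
      }
    }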
PmemBlockOutputStream.scala
@@ -52,9 +52,10 @@ private[spark] class PmemBlockOutputStream(
   //persistentMemoryWriter.updateShuffleMeta(blockId.name)

   val pmemOutputStream: PmemOutputStream = new PmemOutputStream(
-    persistentMemoryWriter, numPartitions, blockId.name, numMaps)
+    persistentMemoryWriter, numPartitions, blockId.name, numMaps, (pmofConf.spill_throttle.toInt + 1024))
   val serInstance = serializer.newInstance()
-  var objStream: SerializationStream = serInstance.serializeStream(pmemOutputStream)
+  val bs = serializerManager.wrapStream(blockId, pmemOutputStream)
+  var objStream: SerializationStream = serInstance.serializeStream(bs)

   override def write(key: Any, value: Any): Unit = {
     objStream.writeKey(key)
@@ -68,12 +69,16 @@
   }

   override def close() {
+    if (objStream != null) {
+      objStream.close()
+      objStream = null
+    }
     pmemOutputStream.close()
-    objStream = null
   }

   override def flush() {
-    objStream.flush()
+    bs.flush()
   }

   def maybeSpill(force: Boolean = false): Unit = {
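The output buffer is now sized from pmofConf.spill_throttle instead of being hard-coded. A back-of-envelope check of the sizing (the 2 MB throttle value below is an assumption, not read from this diff):

    val spillThrottle = 2048 * 1024          // assumed spill_throttle value, in bytes
    val bufferSize    = spillThrottle + 1024 // what PmemBlockOutputStream now passes
    assert(bufferSize == 2049 * 1024)        // 2,098,176 bytes, matching the new
                                             // DEFAULT_BUFSIZE in native/src/PmemBuffer.h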
PmemManagedBuffer.scala
@@ -5,13 +5,15 @@ import java.nio.ByteBuffer
 import sun.misc.Cleaner
 import io.netty.buffer.Unpooled
 import java.util.concurrent.atomic.AtomicInteger
+import io.netty.buffer.ByteBuf

 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.ManagedBuffer

 class PmemManagedBuffer(pmHandler: PersistentMemoryHandler, blockId: String) extends ManagedBuffer with Logging {
   var inputStream: InputStream = _
   var total_size: Long = -1
+  var buf: ByteBuf = _
   var byteBuffer: ByteBuffer = _
   private val refCount = new AtomicInteger(1)
@@ -26,8 +28,13 @@
     // TODO: This function should be Deprecated by spark in near future.
     val data_length = size().toInt
     val in = createInputStream()
-    byteBuffer = ByteBuffer.allocateDirect(data_length)
     val data = Array.ofDim[Byte](data_length)
+    if (buf == null) {
+      buf = NettyByteBufferPool.allocateNewBuffer(data_length)
+      byteBuffer = buf.nioBuffer(0, data_length)
+    } else {
+      byteBuffer.clear()
+    }
     in.read(data)
     byteBuffer.put(data)
     byteBuffer.flip()
@@ -48,12 +55,15 @@
   override def release(): ManagedBuffer = {
     if (refCount.decrementAndGet() == 0) {
-      if (byteBuffer != null) {
+      if (buf != null) {
+        NettyByteBufferPool.releaseBuffer(buf)
+      }
+      /*if (byteBuffer != null) {
         val cleanerField: java.lang.reflect.Field = byteBuffer.getClass.getDeclaredField("cleaner")
         cleanerField.setAccessible(true)
         val cleaner: Cleaner = cleanerField.get(byteBuffer).asInstanceOf[Cleaner]
         cleaner.clean()
-      }
+      }*/
       if (inputStream != null) {
         inputStream.close()
       }
@@ -62,8 +72,8 @@
     }
   }

   override def convertToNetty(): Object = {
-    val data_length = size().toInt
     val in = createInputStream()
+    val data_length = size().toInt
     Unpooled.wrappedBuffer(in.asInstanceOf[PmemInputStream].getByteBufferDirectAddr, data_length, false)
   }
 }
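nioByteBuffer now caches one pooled ByteBuf per PmemManagedBuffer and reuses it on later calls (byteBuffer.clear() only rewinds it), and release() returns the buffer to NettyByteBufferPool, retiring the reflective DirectByteBuffer Cleaner hack as commented-out code. One caveat: InputStream.read(data) is not guaranteed to fill the whole array in a single call in general, so this copy relies on PmemInputStream's specific behavior.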
PmemOutputStream.scala
@@ -10,15 +10,16 @@ class PmemOutputStream(
   persistentMemoryWriter: PersistentMemoryHandler,
   numPartitions: Int,
   blockId: String,
-  numMaps: Int
+  numMaps: Int,
+  bufferSize: Int
 ) extends OutputStream with Logging {
   var set_clean = true
   var is_closed = false

-  val length: Int = 1024*1024*6
+  val length: Int = bufferSize
   var bufferFlushedSize: Int = 0
   var bufferRemainingSize: Int = 0
-  val buf: ByteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(length, length)
+  val buf: ByteBuf = NettyByteBufferPool.allocateNewBuffer(length)
   val byteBuffer: ByteBuffer = buf.nioBuffer(0, length)

   override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
@@ -60,7 +61,7 @@
     if (!is_closed) {
       flush()
       reset()
-      buf.release()
+      NettyByteBufferPool.releaseBuffer(buf)
       is_closed = true
     }
   }
PmemExternalSorter.scala
@@ -103,9 +103,12 @@ private[spark] class PmemExternalSorter[K, V, C](
       if (cur_partitionId != partitionId) {
         if (cur_partitionId != -1) {
           buffer.maybeSpill(true)
+          buffer.close()
+          buffer = null
         }
         cur_partitionId = partitionId
         buffer = getPartitionByteBufferArray(dep.shuffleId, cur_partitionId)
+        logDebug(s"${dep.shuffleId}_${cur_partitionId} ${NettyByteBufferPool}")
       }
       require(partitionId >= 0 && partitionId < numPartitions,
         s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
@@ -115,6 +118,8 @@
       }
       if (buffer != null) {
         buffer.maybeSpill(true)
+        buffer.close()
+        buffer = null
       }
     }
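Closing each partition's buffer as soon as the sorter moves past it returns the fixed-size pooled ByteBuf immediately, rather than holding one live buffer per partition until the sort finishes; the new logDebug line interpolates the NettyByteBufferPool object itself, so its toString reports the [refCnt|allocatedBytes|Peak|Native] counters at every partition switch.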
14 changes: 8 additions & 6 deletions native/src/PmemBuffer.h
@@ -6,7 +6,7 @@
 #include <cstring>
 using namespace std;

-#define DEFAULT_BUFSIZE 4096*1024+512
+#define DEFAULT_BUFSIZE 2049 * 1024

 class PmemBuffer {
  public:
@@ -17,6 +17,7 @@ class PmemBuffer {
     pos = 0;
     pos_dirty = 0;
   }
+
   explicit PmemBuffer(long initial_buf_data_capacity) {
     buf_data_capacity = initial_buf_data_capacity;
     buf_data = (char*)malloc(sizeof(char) * buf_data_capacity);
@@ -39,11 +40,11 @@
     std::lock_guard<std::mutex> lock(buffer_mtx);
     if (buf_data_capacity == 0 && pmem_data_len > 0) {
       buf_data_capacity = pmem_data_len;
-      buf_data = (char*)malloc(sizeof(char) * pmem_data_len);
+      buf_data = (char*)malloc(sizeof(char) * buf_data_capacity);
     }

     if (remaining > 0) {
-      if (buf_data_capacity < remaining+pmem_data_len) {
+      if (buf_data_capacity < remaining + pmem_data_len) {
         buf_data_capacity = remaining + pmem_data_len;
         char* tmp_buf_data = buf_data;
         buf_data = (char*)malloc(sizeof(char) * buf_data_capacity);
@@ -118,6 +119,10 @@
     return read_len;
   }

+  char* getDataAddr() {
+    return buf_data;
+  }
+
   int write(char* data, int len) {
     std::lock_guard<std::mutex> lock(buffer_mtx);
     if (buf_data_capacity == 0) {
@@ -149,9 +154,6 @@
     return 0;
   }

-  char* getDataAddr() {
-    return buf_data;
-  }

  private:
   mutex buffer_mtx;
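The new DEFAULT_BUFSIZE (2049 * 1024 = 2,098,176 bytes) matches the spill_throttle + 1024 sizing introduced in PmemBlockOutputStream above, assuming a 2 MB throttle, and getDataAddr() is only relocated, not changed. One nit a reviewer might raise: the macro body is still unparenthesized, so an expression like x / DEFAULT_BUFSIZE parses as (x / 2049) * 1024; defining it as (2049 * 1024) would be safer.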
