diff --git a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowScan.scala b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowScan.scala
index f5d8fc0a5..4e6695ea7 100644
--- a/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowScan.scala
+++ b/arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowScan.scala
@@ -74,27 +74,36 @@ case class ArrowScan(
     this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
 
   // compute maxSplitBytes
-  def maxSplitBytes(
-      sparkSession: SparkSession,
-      selectedPartitions: Seq[PartitionDirectory]): Long = {
+  def maxSplitBytes(sparkSession: SparkSession,
+                    selectedPartitions: Seq[PartitionDirectory]): Long = {
     val defaultMaxSplitBytes = sparkSession.sessionState.conf.filesMaxPartitionBytes
-    val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes
+    val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes
     // val minPartitionNum = sparkSession.sessionState.conf.filesMinPartitionNum
     //   .getOrElse(sparkSession.leafNodeDefaultParallelism)
     val minPartitionNum = sparkSession.sessionState.conf.filesMinPartitionNum
       .getOrElse(SparkShimLoader.getSparkShims.leafNodeDefaultParallelism(sparkSession))
-    val PREFERRED_PARTITION_SIZE_LOWER_BOUND: Long = 128 * 1024 * 1024
-    val PREFERRED_PARTITION_SIZE_UPPER_BOUND: Long = 512 * 1024 * 1024
+    val PREFERRED_PARTITION_SIZE_LOWER_BOUND: Long = 256 * 1024 * 1024
+    val PREFERRED_PARTITION_SIZE_UPPER_BOUND: Long = 1024 * 1024 * 1024
     val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum
-    val bytesPerCore = totalBytes / minPartitionNum
-
-    Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore))
+    var maxBytesPerCore = totalBytes / minPartitionNum
+    var bytesPerCoreFinal = maxBytesPerCore
+    var bytesPerCore = maxBytesPerCore
+    var i = 2
+    while (bytesPerCore > PREFERRED_PARTITION_SIZE_UPPER_BOUND && i < 4) {
+      bytesPerCore = maxBytesPerCore / i
+      if (bytesPerCore > PREFERRED_PARTITION_SIZE_LOWER_BOUND) {
+        bytesPerCoreFinal = bytesPerCore
+      }
+      i = i + 1
+    }
+    Math.min(PREFERRED_PARTITION_SIZE_UPPER_BOUND, bytesPerCoreFinal)
+    // Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore))
   }
 
   override def partitions: Seq[FilePartition] = {
     val selectedPartitions = fileIndex.listFiles(partitionFilters, dataFilters)
     // val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
-    val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
+    val maxSplitBytes = this.maxSplitBytes(sparkSession, selectedPartitions)
     // val partitionAttributes = fileIndex.partitionSchema.toAttributes
     val partitionAttributes = ScanUtils.toAttributes(fileIndex)
     val attributeMap = partitionAttributes.map(a => normalizeName(a.name) -> a).toMap
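
Reviewer note: the sketch below restates the new per-core clamping heuristic from the hunk above as a self-contained Scala object, so its arithmetic can be checked in isolation. The object and helper names (SplitSizeSketch, preferredMaxSplitBytes) and the sample inputs are illustrative assumptions, not part of the patch; the patch itself inlines this logic in ArrowScan.maxSplitBytes.

object SplitSizeSketch {
  // Same bounds as the patched ArrowScan.maxSplitBytes.
  val PREFERRED_PARTITION_SIZE_LOWER_BOUND: Long = 256L * 1024 * 1024  // 256 MiB
  val PREFERRED_PARTITION_SIZE_UPPER_BOUND: Long = 1024L * 1024 * 1024 // 1 GiB

  // Hypothetical standalone helper mirroring the inlined loop in the patch.
  def preferredMaxSplitBytes(totalBytes: Long, minPartitionNum: Long): Long = {
    val maxBytesPerCore = totalBytes / minPartitionNum
    var bytesPerCoreFinal = maxBytesPerCore
    var bytesPerCore = maxBytesPerCore
    var i = 2
    // Shrink the per-core size by successive divisors (2, then 3) until it
    // falls under the upper bound, accepting a candidate only while it still
    // exceeds the lower bound; then cap the result at the upper bound.
    while (bytesPerCore > PREFERRED_PARTITION_SIZE_UPPER_BOUND && i < 4) {
      bytesPerCore = maxBytesPerCore / i
      if (bytesPerCore > PREFERRED_PARTITION_SIZE_LOWER_BOUND) {
        bytesPerCoreFinal = bytesPerCore
      }
      i = i + 1
    }
    Math.min(PREFERRED_PARTITION_SIZE_UPPER_BOUND, bytesPerCoreFinal)
  }

  def main(args: Array[String]): Unit = {
    // 6 GiB over 2 cores: 3 GiB/core exceeds the 1 GiB cap, so the loop tries
    // 3 GiB / 2 = 1.5 GiB, then 3 GiB / 3 = 1 GiB; the final cap yields 1 GiB.
    println(preferredMaxSplitBytes(6L * 1024 * 1024 * 1024, 2)) // 1073741824
    // 512 MiB over 2 cores is already under the cap: returns 256 MiB as-is.
    println(preferredMaxSplitBytes(512L * 1024 * 1024, 2))      // 268435456
  }
}

One behavioral consequence worth noting in review: unlike the replaced FilePartition.maxSplitBytes, this version no longer consults spark.sql.files.maxPartitionBytes (defaultMaxSplitBytes) or openCostInBytes when choosing the final value, so split sizes are governed purely by the 256 MiB / 1 GiB bounds and the per-core total.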