Skip to content

Commit

Permalink
Added SVE implementation to improve the performance on ARM architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
divya2108 committed Aug 7, 2024
1 parent cc3b56f commit 5194c17
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 7 deletions.
45 changes: 45 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,51 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "OS400")
set(CMAKE_CXX_ARCHIVE_CREATE "<CMAKE_AR> -X64 qc <TARGET> <OBJECTS>")
endif()

if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
include(CheckCSourceCompiles)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8-a+sve")
check_c_source_compiles("
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>
int main() {
svfloat64_t a;
a = svdup_n_f64(0);
return 0;
}
#endif
" COMPILER_HAS_ARM_SVE)

if(COMPILER_HAS_ARM_SVE)
message(STATUS "ARM SVE compiler support detected")
set(SOURCE_CODE "
#include <sys/prctl.h>
int main() {
int ret = prctl(PR_SVE_GET_VL);
return ret >= 0 ? 0 : 1;
}
")
file(WRITE ${CMAKE_BINARY_DIR}/check_sve_support.c "${SOURCE_CODE}")
try_run(RUN_RESULT COMPILE_RESULT
${CMAKE_BINARY_DIR}/check_sve_support_output
${CMAKE_BINARY_DIR}/check_sve_support.c
)

if(RUN_RESULT EQUAL 0)
message(STATUS "ARM SVE hardware support detected")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8-a+sve")
string(APPEND CMAKE_CXX_FLAGS " -DSVE_SUPPORT_DETECTED")
else()
message(STATUS "ARM SVE hardware support not detected")
endif()
else()
message(STATUS "ARM SVE compiler support not detected")
endif()

set(CMAKE_C_FLAGS "${ORIGINAL_CMAKE_C_FLAGS}")
else()
message(STATUS "Not an aarch64 architecture")
endif()

if(USE_NCCL)
find_package(Nccl REQUIRED)
endif()
Expand Down
58 changes: 51 additions & 7 deletions src/common/hist_util.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/**
* Copyright 2017-2023 by XGBoost Contributors
* Copyright 2024 FUJITSU LIMITED
* \file hist_util.cc
*/
#include "hist_util.h"
Expand All @@ -15,6 +16,10 @@
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for SparsePage, SortedCSCPage

#if defined(SVE_SUPPORT_DETECTED)
#include <arm_sve.h> // to leverage sve intrinsics
#endif

#if defined(XGBOOST_MM_PREFETCH_PRESENT)
#include <xmmintrin.h>
#define PREFETCH_READ_T0(addr) _mm_prefetch(reinterpret_cast<const char*>(addr), _MM_HINT_T0)
Expand Down Expand Up @@ -252,13 +257,52 @@ void RowsWiseBuildHistKernel(Span<GradientPair const> gpair, Span<bst_idx_t cons

// The trick with pgh_t buffer helps the compiler to generate faster binary.
const float pgh_t[] = {p_gpair[idx_gh], p_gpair[idx_gh + 1]};
for (size_t j = 0; j < row_size; ++j) {
const uint32_t idx_bin =
two * (static_cast<uint32_t>(gr_index_local[j]) + (kAnyMissing ? 0 : offsets[j]));
auto hist_local = hist_data + idx_bin;
*(hist_local) += pgh_t[0];
*(hist_local + 1) += pgh_t[1];
}
#if defined(SVE_SUPPORT_DETECTED)
svfloat64_t pgh_t0_vec = svdup_n_f64(pgh_t[0]);
svfloat64_t pgh_t1_vec = svdup_n_f64(pgh_t[1]);

for (size_t j = 0; j < row_size; j+=svcntw()) {
svbool_t pg32 = svwhilelt_b32(j, row_size);
svbool_t pg64 = svwhilelt_b64(j, row_size);
svuint32_t gr_index_vec =
svld1ub_u32(pg32, reinterpret_cast<const uint8_t *> (&gr_index_local[j]));
svuint32_t offsets_vec = svld1(pg32, &offsets[j]);
svuint32_t idx_bin_vec;
if (kAnyMissing) {
idx_bin_vec = svmul_n_u32_x(pg32, gr_index_vec, two);
} else {
svuint32_t temp = svadd_u32_m(pg32, gr_index_vec, offsets_vec);
idx_bin_vec = svmul_n_u32_x(pg32, temp, two);
}
svuint64_t idx_bin_vec0_0 = svunpklo_u64(idx_bin_vec);
svuint64_t idx_bin_vec0_1 = svunpkhi_u64(idx_bin_vec);
svuint64_t idx_bin_vec1_0 = svadd_n_u64_m(pg64, idx_bin_vec0_0, 1);
svuint64_t idx_bin_vec1_1 = svadd_n_u64_m(pg64, idx_bin_vec0_1, 1);

svfloat64_t hist0_vec0 = svld1_gather_index(pg64, hist_data, idx_bin_vec0_0);
svfloat64_t hist0_vec1 = svld1_gather_index(pg64, hist_data, idx_bin_vec0_1);
svfloat64_t hist1_vec0 = svld1_gather_index(pg64, hist_data, idx_bin_vec1_0);
svfloat64_t hist1_vec1 = svld1_gather_index(pg64, hist_data, idx_bin_vec1_1);

hist0_vec0 = svadd_f64_m(pg64, hist0_vec0, pgh_t0_vec);
hist0_vec1 = svadd_f64_m(pg64, hist0_vec1, pgh_t0_vec);
hist1_vec0 = svadd_f64_m(pg64, hist1_vec0, pgh_t1_vec);
hist1_vec1 = svadd_f64_m(pg64, hist1_vec1, pgh_t1_vec);

svst1_scatter_index(pg64, hist_data, idx_bin_vec0_0, hist0_vec0);
svst1_scatter_index(pg64, hist_data, idx_bin_vec0_1, hist0_vec1);
svst1_scatter_index(pg64, hist_data, idx_bin_vec1_0, hist1_vec0);
svst1_scatter_index(pg64, hist_data, idx_bin_vec1_1, hist1_vec1);
}
#else
for (size_t j = 0; j < row_size; ++j) {
const uint32_t idx_bin =
two * (static_cast<uint32_t>(gr_index_local[j]) + (kAnyMissing ? 0 : offsets[j]));
auto hist_local = hist_data + idx_bin;
*(hist_local) += pgh_t[0];
*(hist_local + 1) += pgh_t[1];
}
#endif
}
}

Expand Down

0 comments on commit 5194c17

Please sign in to comment.