diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 66c0092e6..3bf7882b8 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,6 +29,7 @@ include: - /deps-docs.yaml - /deps-rocm.yaml - /deps-vcpkg.yaml + - /deps-windows.yaml - /gpus-rocm.yaml - /rules.yaml @@ -247,6 +248,7 @@ build:benchmark: -S $CI_PROJECT_DIR -G Ninja -D CMAKE_CXX_COMPILER="$AMDCLANG" + -D CMAKE_CXX_FLAGS="-Wall -Wextra -Werror -Wno-#pragma-messages" -D CMAKE_BUILD_TYPE=Release -D BUILD_TEST=OFF -D BUILD_EXAMPLE=OFF @@ -260,6 +262,48 @@ build:benchmark: - $BUILD_DIR/deps/googlebenchmark/ expire_in: 2 weeks +build:windows: + stage: build + needs: [] + extends: + - .rules:build + - .gpus:rocm-windows + - .deps:rocm-windows + - .deps:visual-studio-devshell + parallel: + matrix: + - BUILD_TYPE: + # Disabled due to extensive link times. 
+ # This is tracked in issue 679 + #- Debug + - Release + BUILD_TARGET: + - BENCHMARK + - TEST + script: + - mkdir -p $CI_PROJECT_DIR/build + - cmake -G Ninja + -S $CI_PROJECT_DIR + -B $CI_PROJECT_DIR/build + -D BUILD_$BUILD_TARGET=ON + -D GPU_TARGETS=$GPU_TARGET + -D CMAKE_CXX_COMPILER:PATH="${env:HIP_PATH}\bin\clang++.exe" + -D CMAKE_C_COMPILER:PATH="${env:HIP_PATH}\bin\clang.exe" + -D CMAKE_PREFIX_PATH:PATH="${env:HIP_PATH}" + -D CMAKE_BUILD_TYPE="$BUILD_TYPE" + - cmake --build "$CI_PROJECT_DIR/build" + artifacts: + paths: + - $CI_PROJECT_DIR/build/test/test_* + - $CI_PROJECT_DIR/build/test/rocprim/test_* + - $CI_PROJECT_DIR/build/test/CTestTestfile.cmake + - $CI_PROJECT_DIR/build/test/rocprim/CTestTestfile.cmake + - $CI_PROJECT_DIR/build/gtest/ + - $CI_PROJECT_DIR/build/CMakeCache.txt + - $CI_PROJECT_DIR/build/.ninja_log + - $CI_PROJECT_DIR/build/CTestTestfile.cmake + expire_in: 2 weeks + autotune:build: stage: autotune needs: [] @@ -289,6 +333,19 @@ autotune:build: -D GPU_TARGETS=$GPU_TARGETS - cmake --build . 
--target $BENCHMARK_TARGETS - 'rm -rf $BUILD_DIR/benchmark/benchmark*.parallel' + # remove benchmark executables if their size together is too large for gitlab ci to handle + - | + total_size_bytes=0 + while read -r file_size; do + total_size_bytes=$((total_size_bytes + file_size)) + done < <(stat --format="%s" benchmark/benchmark*) + total_size_gib="$(numfmt --round=down --to-unit=Gi "$total_size_bytes")" + if [ "$total_size_gib" -ge 3 ]; then + printf "Total size: %s (%d bytes) > 3GiB, skipping benchmark executables from the artifact.\n" \ + "$(numfmt --to=iec-i "$total_size_bytes")" "$total_size_bytes" + rm benchmark/benchmark* + fi + artifacts: paths: - $BUILD_DIR/benchmark/benchmark* @@ -320,6 +377,39 @@ test: --resource-spec-file ./resources.json --parallel $PARALLEL_JOBS +.test-windows-base: + stage: test + extends: + - .deps:rocm-windows + - .gpus:rocm-gpus-windows + - .deps:visual-studio-devshell + - .rules:test + script: + - cd $CI_PROJECT_DIR/build + - ctest --output-on-failure + +# Disabled due to extensive link times. +# This is tracked in issue 679 +# test-windows-debug: +# extends: +# - .test-windows-base +# needs: +# - job: build:windows +# parallel: +# matrix: +# - BUILD_TYPE: Debug +# BUILD_TARGET: TEST + +test-windows-release: + extends: + - .test-windows-base + needs: + - job: build:windows + parallel: + matrix: + - BUILD_TYPE: Release + BUILD_TARGET: TEST + .test-package: script: - cmake @@ -369,6 +459,8 @@ test:deb: test:docs: stage: test + variables: + SPHINX_DIR: $DOCS_DIR/sphinx extends: - .rules:test - .build:docs @@ -472,6 +564,11 @@ autotune:execute-tuning: # On ROCm 5.7 or later, check if this can be removed - the presumption is that the failure is caused by a compiler issue. - > cd "${CI_PROJECT_DIR}" + - | + if [ ! -d "${BUILD_DIR}/benchmark" ]; then + echo "There are no benchmark executables. Run the build job with a BUILD_TARGET." 
+ exit 1 + fi - mkdir -p "${AUTOTUNE_RESULT_DIR}" - python3 .gitlab/run_benchmarks.py diff --git a/CHANGELOG.md b/CHANGELOG.md index baefd45a2..09a057021 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,22 @@ Documentation for rocPRIM is available at [https://rocm.docs.amd.com/projects/rocPRIM/en/latest/](https://rocm.docs.amd.com/projects/rocPRIM/en/latest/). +## Unreleased rocPRIM-3.2.0 for ROCm 6.2.0 + +### Additions + +* New overloads for `warp_scan::exclusive_scan` that take no initial value. These new overloads will write an unspecified result to the first value of each warp. +* The internal accumulator type of `inclusive_scan(_by_key)` and `exclusive_scan(_by_key)` is now exposed as an optional type parameter. + * The default accumulator type is still the value type of the input iterator (inclusive scan) or the initial value's type (exclusive scan). + This is the same behaviour as before this change. +* New overload for `device_adjacent_difference_inplace` that allows separate input and output iterators, but allows them to point to the same element. + +### Fixes + +* Fixed incorrect results of `warp_exchange::blocked_to_striped_shuffle` and `warp_exchange::striped_to_blocked_shuffle` when the block size is + larger than the logical warp size. The test suite has been updated with such cases. +* Fixed incorrect results returned when calling device `unique_by_key` with overlapping `values_input` and `values_output`. + ## Unreleased rocPRIM-3.1.0 for ROCm 6.1.0 ### Additions diff --git a/CMakeLists.txt b/CMakeLists.txt index b2725db26..b92c8ab8c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,6 +29,12 @@ set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "Install path prefix, prepended # rocPRIM project project(rocprim LANGUAGES CXX) +if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) + set(ROCPRIM_PROJECT_IS_TOP_LEVEL TRUE) +else() + set(ROCPRIM_PROJECT_IS_TOP_LEVEL FALSE) +endif() + #Adding CMAKE_PREFIX_PATH if(WIN32) set(ROCM_ROOT "$ENV{HIP_PATH}" CACHE PATH "Root directory of the ROCm installation") @@ -44,6 +50,7 @@ option(USE_HIP_CPU "Prefer HIP-CPU runtime instead of HW acceleration" OFF) # Disables building tests, benchmarks, examples option(ONLY_INSTALL "Only install" OFF) option(BUILD_CODE_COVERAGE "Build with code coverage enabled" OFF) +option(ROCPRIM_INSTALL "Enable installation of rocPRIM (projects embedding rocPRIM may want to turn this OFF)" ON) # CMake modules list(APPEND CMAKE_MODULE_PATH @@ -94,7 +101,7 @@ endif() # FOR HANDLING ENABLE/DISABLE OPTIONAL BACKWARD COMPATIBILITY for FILE/FOLDER REORG option(BUILD_FILE_REORG_BACKWARD_COMPATIBILITY "Build with file/folder reorg with backward compatibility enabled" OFF) -if(BUILD_FILE_REORG_BACKWARD_COMPATIBILITY AND NOT WIN32) +if(ROCPRIM_INSTALL AND BUILD_FILE_REORG_BACKWARD_COMPATIBILITY AND NOT WIN32) rocm_wrap_header_dir( "${PROJECT_SOURCE_DIR}/rocprim/include/rocprim" WRAPPER_LOCATIONS rocprim/include/rocprim @@ -114,7 +121,7 @@ if(USE_HIP_CPU) endif() # Setup VERSION -set(VERSION_STRING "3.1.0") +set(VERSION_STRING "3.2.0") rocm_setup_version(VERSION ${VERSION_STRING}) # Print configuration summary @@ -124,20 +131,24 @@ print_configuration_summary() # rocPRIM library add_subdirectory(rocprim) -if(NOT ONLY_INSTALL AND (BUILD_TEST OR BUILD_BENCHMARK)) +if(ROCPRIM_PROJECT_IS_TOP_LEVEL AND NOT ONLY_INSTALL AND (BUILD_TEST OR BUILD_BENCHMARK)) rocm_package_setup_component(clients) endif() # Tests if(BUILD_TEST AND NOT 
ONLY_INSTALL) - rocm_package_setup_client_component(tests) + if (ROCPRIM_PROJECT_IS_TOP_LEVEL) + rocm_package_setup_client_component(tests) + endif() enable_testing() add_subdirectory(test) endif() # Benchmarks if(BUILD_BENCHMARK AND NOT ONLY_INSTALL) - rocm_package_setup_client_component(benchmarks) + if (ROCPRIM_PROJECT_IS_TOP_LEVEL) + rocm_package_setup_client_component(benchmarks) + endif() add_subdirectory(benchmark) endif() @@ -147,20 +158,21 @@ if(BUILD_EXAMPLE AND NOT ONLY_INSTALL) endif() # Package -set(BUILD_SHARED_LIBS ON) # Build as though shared library for naming -rocm_package_add_dependencies(DEPENDS "hip-rocclr >= 3.5.0") -set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE.txt") -set(CPACK_RPM_PACKAGE_LICENSE "MIT") - -set(CPACK_RPM_EXCLUDE_FROM_AUTO_FILELIST_ADDITION "\${CPACK_PACKAGING_INSTALL_PREFIX}" ) - -rocm_create_package( - NAME rocprim - DESCRIPTION "Radeon Open Compute Parallel Primitives Library" - MAINTAINER "rocPRIM Maintainer " - HEADER_ONLY -) - +if (ROCPRIM_PROJECT_IS_TOP_LEVEL) + set(BUILD_SHARED_LIBS ON) # Build as though shared library for naming + rocm_package_add_dependencies(DEPENDS "hip-rocclr >= 3.5.0") + set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE.txt") + set(CPACK_RPM_PACKAGE_LICENSE "MIT") + + set(CPACK_RPM_EXCLUDE_FROM_AUTO_FILELIST_ADDITION "\${CPACK_PACKAGING_INSTALL_PREFIX}" ) + + rocm_create_package( + NAME rocprim + DESCRIPTION "Radeon Open Compute Parallel Primitives Library" + MAINTAINER "rocPRIM Maintainer " + HEADER_ONLY + ) +endif() # # ADDITIONAL TARGETS FOR CODE COVERAGE diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index 499291a39..5bde961ed 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -95,7 +95,9 @@ function(add_rocprim_benchmark BENCHMARK_SOURCE) RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/benchmark" ) - rocm_install(TARGETS ${BENCHMARK_TARGET} COMPONENT benchmarks) + if (ROCPRIM_INSTALL) + rocm_install(TARGETS ${BENCHMARK_TARGET} COMPONENT benchmarks) + endif() if (WIN32 AND NOT DEFINED DLLS_COPIED) set(DLLS_COPIED "YES") set(DLLS_COPIED ${DLLS_COPIED} PARENT_SCOPE) @@ -145,7 +147,8 @@ add_rocprim_benchmark(benchmark_device_run_length_encode.cpp) add_rocprim_benchmark(benchmark_device_scan.cpp) add_rocprim_benchmark(benchmark_device_scan_by_key.cpp) add_rocprim_benchmark(benchmark_device_select.cpp) -add_rocprim_benchmark(benchmark_device_segmented_radix_sort.cpp) +add_rocprim_benchmark(benchmark_device_segmented_radix_sort_keys.cpp) +add_rocprim_benchmark(benchmark_device_segmented_radix_sort_pairs.cpp) add_rocprim_benchmark(benchmark_device_segmented_reduce.cpp) add_rocprim_benchmark(benchmark_device_transform.cpp) add_rocprim_benchmark(benchmark_warp_exchange.cpp) diff --git a/benchmark/ConfigAutotuneSettings.cmake b/benchmark/ConfigAutotuneSettings.cmake index 510c222ad..85556d894 100644 --- a/benchmark/ConfigAutotuneSettings.cmake +++ b/benchmark/ConfigAutotuneSettings.cmake @@ -81,5 +81,17 @@ ${TUNING_TYPES};${LIMITED_TUNING_TYPES};using_warp_scan reduce_then_scan" PARENT set(list_across "\ binary_search upper_bound lower_bound;${TUNING_TYPES};${LIMITED_TUNING_TYPES};64 128 256;1 2 4 8 16" PARENT_SCOPE) set(output_pattern_suffix "@SubAlgorithm@_@ValueType@_@OutputType@_@BlockSize@_@ItemsPerThread@" PARENT_SCOPE) + elseif(file STREQUAL "benchmark_device_segmented_radix_sort_keys") + set(list_across_names "\ +KeyType;BlockSize;ItemsPerThread;PartitionAllowed" PARENT_SCOPE) + set(list_across "${TUNING_TYPES};128 256;4 8 16;false" PARENT_SCOPE) + set(output_pattern_suffix "\ 
+@KeyType@_@BlockSize@_@ItemsPerThread@_@PartitionAllowed@" PARENT_SCOPE) + elseif(file STREQUAL "benchmark_device_segmented_radix_sort_pairs") + set(list_across_names "\ +KeyType;ValueType;BlockSize;ItemsPerThread;PartitionAllowed" PARENT_SCOPE) + set(list_across "${TUNING_TYPES};int8_t;64;4 8 16;true false" PARENT_SCOPE) + set(output_pattern_suffix "\ +@KeyType@_@ValueType@_@BlockSize@_@ItemsPerThread@_@PartitionAllowed@" PARENT_SCOPE) endif() endfunction() diff --git a/benchmark/benchmark_block_run_length_decode.cpp b/benchmark/benchmark_block_run_length_decode.cpp index 04e1f0428..34adfdc14 100644 --- a/benchmark/benchmark_block_run_length_decode.cpp +++ b/benchmark/benchmark_block_run_length_decode.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2021-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -57,7 +57,6 @@ __global__ rocprim::block_load_direct_blocked(global_thread_idx, d_run_items, run_items); rocprim::block_load_direct_blocked(global_thread_idx, d_run_offsets, run_offsets); - ROCPRIM_SHARED_MEMORY typename BlockRunLengthDecodeT::storage_type temp_storage; BlockRunLengthDecodeT block_run_length_decode(run_items, run_offsets); const OffsetT total_decoded_size diff --git a/benchmark/benchmark_device_histogram.parallel.hpp b/benchmark/benchmark_device_histogram.parallel.hpp index 6137ba7ae..146c1cc40 100644 --- a/benchmark/benchmark_device_histogram.parallel.hpp +++ b/benchmark/benchmark_device_histogram.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -358,8 +358,8 @@ struct device_histogram_benchmark_generator template - auto create(std::vector>& storage, - const std::vector& cases) -> + auto create(std::vector>& /*storage*/, + const std::vector& /*cases*/) -> typename std::enable_if::type {} diff --git a/benchmark/benchmark_device_segmented_radix_sort.cpp b/benchmark/benchmark_device_segmented_radix_sort.cpp deleted file mode 100644 index e471f5201..000000000 --- a/benchmark/benchmark_device_segmented_radix_sort.cpp +++ /dev/null @@ -1,521 +0,0 @@ -// MIT License -// -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. 
- -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" -// CmdParser -#include "cmdparser.hpp" -#include "benchmark_utils.hpp" - -// HIP API -#include - -// rocPRIM -#include - -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; -#endif - -namespace rp = rocprim; - -namespace -{ - -constexpr unsigned int warmup_size = 2; -constexpr size_t min_size = 30000; -constexpr std::array segment_counts{ 10, 100, 1000, 2500, 5000, 7500, 10000, 100000 }; -constexpr std::array segment_lengths{30, 256, 3000, 300000}; -} - - -template -void run_sort_keys_benchmark(benchmark::State& state, - size_t num_segments, - size_t mean_segment_length, - size_t target_size, - hipStream_t stream) -{ - using offset_type = int; - using key_type = Key; - - std::vector offsets; - offsets.push_back(0); - - static constexpr int seed = 716; - std::default_random_engine gen(seed); - - std::normal_distribution segment_length_dis(static_cast(mean_segment_length), - 0.1 * mean_segment_length); - - size_t offset = 0; - for(size_t segment_index = 0; segment_index < num_segments;) - { - const double segment_length_candidate = std::round(segment_length_dis(gen)); - if (segment_length_candidate < 0) - { - continue; - } - const offset_type segment_length = static_cast(segment_length_candidate); - offset += segment_length; - offsets.push_back(offset); - ++segment_index; - } - const size_t size = offset; - const size_t segments_count = offsets.size() - 1; - - std::vector keys_input; - if(std::is_floating_point::value) - { - keys_input = get_random_data( - size, - static_cast(-1000), - static_cast(1000) - ); - } - else - { - keys_input = get_random_data( - size, - std::numeric_limits::min(), - std::numeric_limits::max() - ); - } - size_t batch_size = 1; - if(size < target_size) - { - batch_size = (target_size + size - 1) / size; - } - - offset_type * d_offsets; - HIP_CHECK(hipMalloc(&d_offsets, offsets.size() * sizeof(offset_type))); - HIP_CHECK( - 
hipMemcpy( - d_offsets, offsets.data(), - offsets.size() * sizeof(offset_type), - hipMemcpyHostToDevice - ) - ); - - key_type * d_keys_input; - key_type * d_keys_output; - HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); - HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); - HIP_CHECK( - hipMemcpy( - d_keys_input, keys_input.data(), - size * sizeof(key_type), - hipMemcpyHostToDevice - ) - ); - - void * d_temporary_storage = nullptr; - size_t temporary_storage_bytes = 0; - HIP_CHECK( - rp::segmented_radix_sort_keys( - d_temporary_storage, temporary_storage_bytes, - d_keys_input, d_keys_output, size, - segments_count, d_offsets, d_offsets + 1, - 0, sizeof(key_type) * 8, - stream, false - ) - ); - - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - HIP_CHECK(hipDeviceSynchronize()); - - // Warm-up - for(size_t i = 0; i < warmup_size; i++) - { - HIP_CHECK( - rp::segmented_radix_sort_keys( - d_temporary_storage, temporary_storage_bytes, - d_keys_input, d_keys_output, size, - segments_count, d_offsets, d_offsets + 1, - 0, sizeof(key_type) * 8, - stream, false - ) - ); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for (auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; i++) - { - HIP_CHECK( - rp::segmented_radix_sort_keys( - d_temporary_storage, temporary_storage_bytes, - d_keys_input, d_keys_output, size, - segments_count, d_offsets, d_offsets + 1, - 0, sizeof(key_type) * 8, - stream, false - ) - ); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - 
HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_offsets)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys_output)); -} - -template -void run_sort_pairs_benchmark(benchmark::State& state, - size_t num_segments, - size_t mean_segment_length, - size_t target_size, - hipStream_t stream) -{ - using offset_type = int; - using key_type = Key; - using value_type = Value; - - // Generate data - std::vector offsets; - offsets.push_back(0); - - static constexpr int seed = 716; - std::default_random_engine gen(seed); - - std::normal_distribution segment_length_dis(static_cast(mean_segment_length), - 0.1 * mean_segment_length); - - size_t offset = 0; - for(size_t segment_index = 0; segment_index < num_segments;) - { - const double segment_length_candidate = std::round(segment_length_dis(gen)); - if (segment_length_candidate < 0) - { - continue; - } - const offset_type segment_length = static_cast(segment_length_candidate); - offset += segment_length; - offsets.push_back(offset); - ++segment_index; - } - const size_t size = offset; - const size_t segments_count = offsets.size() - 1; - - std::vector keys_input; - if(std::is_floating_point::value) - { - keys_input = get_random_data( - size, - static_cast(-1000), - static_cast(1000) - ); - } - else - { - keys_input = get_random_data( - size, - std::numeric_limits::min(), - std::numeric_limits::max() - ); - } - size_t batch_size = 1; - if(size < target_size) - { - batch_size = (target_size + size - 1) / size; - } - - std::vector values_input(size); - std::iota(values_input.begin(), values_input.end(), 0); - - offset_type * d_offsets; - HIP_CHECK(hipMalloc(&d_offsets, (segments_count + 1) * sizeof(offset_type))); - HIP_CHECK( - hipMemcpy( - d_offsets, offsets.data(), - 
(segments_count + 1) * sizeof(offset_type), - hipMemcpyHostToDevice - ) - ); - - key_type * d_keys_input; - key_type * d_keys_output; - HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); - HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); - HIP_CHECK( - hipMemcpy( - d_keys_input, keys_input.data(), - size * sizeof(key_type), - hipMemcpyHostToDevice - ) - ); - - value_type * d_values_input; - value_type * d_values_output; - HIP_CHECK(hipMalloc(&d_values_input, size * sizeof(value_type))); - HIP_CHECK(hipMalloc(&d_values_output, size * sizeof(value_type))); - HIP_CHECK( - hipMemcpy( - d_values_input, values_input.data(), - size * sizeof(value_type), - hipMemcpyHostToDevice - ) - ); - - void * d_temporary_storage = nullptr; - size_t temporary_storage_bytes = 0; - HIP_CHECK( - rp::segmented_radix_sort_pairs( - d_temporary_storage, temporary_storage_bytes, - d_keys_input, d_keys_output, d_values_input, d_values_output, size, - segments_count, d_offsets, d_offsets + 1, - 0, sizeof(key_type) * 8, - stream, false - ) - ); - - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - HIP_CHECK(hipDeviceSynchronize()); - - // Warm-up - for(size_t i = 0; i < warmup_size; i++) - { - HIP_CHECK( - rp::segmented_radix_sort_pairs( - d_temporary_storage, temporary_storage_bytes, - d_keys_input, d_keys_output, d_values_input, d_values_output, size, - segments_count, d_offsets, d_offsets + 1, - 0, sizeof(key_type) * 8, - stream, false - ) - ); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for (auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; i++) - { - HIP_CHECK( - rp::segmented_radix_sort_pairs( - d_temporary_storage, temporary_storage_bytes, - d_keys_input, d_keys_output, d_values_input, d_values_output, size, - segments_count, d_offsets, 
d_offsets + 1, - 0, sizeof(key_type) * 8, - stream, false - ) - ); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed( - state.iterations() * batch_size * size * (sizeof(key_type) + sizeof(value_type)) - ); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_offsets)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys_output)); - HIP_CHECK(hipFree(d_values_input)); - HIP_CHECK(hipFree(d_values_output)); -} - -template -void add_sort_keys_benchmarks(std::vector &benchmarks, - hipStream_t stream, - size_t max_size, - size_t min_size, - size_t target_size) -{ - std::string key_name = Traits::name(); - std::string value_name = Traits::name(); - for(const auto segment_count : segment_counts) - { - for(const auto segment_length : segment_lengths) - { - const auto number_of_elements = segment_count * segment_length; - if(number_of_elements > max_size || number_of_elements < min_size) - { - continue; - } - benchmarks.push_back(benchmark::RegisterBenchmark( - bench_naming::format_name( - "{lvl:device,algo:radix_sort_segmented,key_type:" + key_name + ",value_type:" - + value_name + ",segment_count:" + std::to_string(segment_count) - + ",segment_length:" + std::to_string(segment_length) + ",cfg:default_config}") - .c_str(), - [=](benchmark::State& state) { - run_sort_keys_benchmark(state, - segment_count, - segment_length, - target_size, - stream); - })); - } - } -} - -template -void add_sort_pairs_benchmarks(std::vector &benchmarks, - hipStream_t stream, - size_t max_size, - size_t min_size, - size_t target_size) -{ - 
std::string key_name = Traits::name(); - std::string value_name = Traits::name(); - for(const auto segment_count : segment_counts) - { - for(const auto segment_length : segment_lengths) - { - const auto number_of_elements = segment_count * segment_length; - if(number_of_elements > max_size || number_of_elements < min_size) - { - continue; - } - benchmarks.push_back(benchmark::RegisterBenchmark( - bench_naming::format_name( - "{lvl:device,algo:radix_sort_segmented,key_type:" + key_name + ",value_type:" - + value_name + ",segment_count:" + std::to_string(segment_count) - + ",segment_length:" + std::to_string(segment_length) + ",cfg:default_config}") - .c_str(), - [=](benchmark::State& state) - { - run_sort_pairs_benchmark(state, - segment_count, - segment_length, - target_size, - stream); - })); - } - } -} - -int main(int argc, char *argv[]) -{ - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); - - // Add benchmarks - std::vector benchmarks; - add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_keys_benchmarks(benchmarks, stream, size, min_size, 
size / 2); - - using custom_float2 = custom_type; - using custom_double2 = custom_type; - add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; -} diff --git a/benchmark/benchmark_device_segmented_radix_sort_keys.cpp b/benchmark/benchmark_device_segmented_radix_sort_keys.cpp new file mode 100644 index 000000000..dfd7e14e9 --- /dev/null +++ b/benchmark/benchmark_device_segmented_radix_sort_keys.cpp @@ -0,0 +1,322 @@ +// MIT License +// +// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include +#include +#include +#include +#include + +// Google Benchmark +#include "benchmark/benchmark.h" +// CmdParser +#include "benchmark_utils.hpp" +#include "cmdparser.hpp" + +// HIP API +#include + +// rocPRIM +#include + +#ifndef DEFAULT_N +const size_t DEFAULT_N = 1024 * 1024 * 32; +#endif + +namespace rp = rocprim; + +namespace +{ + +constexpr unsigned int warmup_size = 2; +constexpr size_t min_size = 30000; +constexpr std::array segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; +constexpr std::array segment_lengths{30, 256, 3000, 300000}; +} // namespace + +// This benchmark only handles the rocprim::segmented_radix_sort_keys function. The benchmark was separated into two (keys and pairs), +// because the binary became too large to link: linking ran into a "relocation R_X86_64_PC32 out of range" error. +// This happens partly because the algorithm has 4 kernels and decides at runtime which one to call. 
+ +template +void run_sort_keys_benchmark(benchmark::State& state, + size_t num_segments, + size_t mean_segment_length, + size_t target_size, + hipStream_t stream) +{ + using offset_type = int; + using key_type = Key; + + std::vector offsets; + offsets.push_back(0); + + static constexpr int seed = 716; + std::default_random_engine gen(seed); + + std::normal_distribution segment_length_dis(static_cast(mean_segment_length), + 0.1 * mean_segment_length); + + size_t offset = 0; + for(size_t segment_index = 0; segment_index < num_segments;) + { + const double segment_length_candidate = std::round(segment_length_dis(gen)); + if(segment_length_candidate < 0) + { + continue; + } + const offset_type segment_length = static_cast(segment_length_candidate); + offset += segment_length; + offsets.push_back(offset); + ++segment_index; + } + const size_t size = offset; + const size_t segments_count = offsets.size() - 1; + + std::vector keys_input; + if(std::is_floating_point::value) + { + keys_input = get_random_data(size, + static_cast(-1000), + static_cast(1000)); + } + else + { + keys_input = get_random_data(size, + std::numeric_limits::min(), + std::numeric_limits::max()); + } + size_t batch_size = 1; + if(size < target_size) + { + batch_size = (target_size + size - 1) / size; + } + + offset_type* d_offsets; + HIP_CHECK(hipMalloc(&d_offsets, offsets.size() * sizeof(offset_type))); + HIP_CHECK(hipMemcpy(d_offsets, + offsets.data(), + offsets.size() * sizeof(offset_type), + hipMemcpyHostToDevice)); + + key_type* d_keys_input; + key_type* d_keys_output; + HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); + HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); + HIP_CHECK( + hipMemcpy(d_keys_input, keys_input.data(), size * sizeof(key_type), hipMemcpyHostToDevice)); + + void* d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; + HIP_CHECK(rp::segmented_radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + 
d_keys_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + HIP_CHECK(hipDeviceSynchronize()); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + HIP_CHECK(rp::segmented_radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + } + HIP_CHECK(hipDeviceSynchronize()); + + // HIP events creation + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for(auto _ : state) + { + // Record start event + HIP_CHECK(hipEventRecord(start, stream)); + + for(size_t i = 0; i < batch_size; i++) + { + HIP_CHECK(rp::segmented_radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + } + + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_mseconds; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + + // Destroy HIP events + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_offsets)); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_keys_output)); +} + +template +void add_sort_keys_benchmarks(std::vector& benchmarks, + hipStream_t stream, + size_t max_size, + size_t min_size, + size_t target_size) +{ + std::string key_name = Traits::name(); + std::string value_name = Traits::name(); + 
for(const auto segment_count : segment_counts) + { + for(const auto segment_length : segment_lengths) + { + const auto number_of_elements = segment_count * segment_length; + if(number_of_elements > max_size || number_of_elements < min_size) + { + continue; + } + benchmarks.push_back(benchmark::RegisterBenchmark( + bench_naming::format_name( + "{lvl:device,algo:radix_sort_segmented,key_type:" + key_name + ",value_type:" + + value_name + ",segment_count:" + std::to_string(segment_count) + + ",segment_length:" + std::to_string(segment_length) + ",cfg:default_config}") + .c_str(), + [=](benchmark::State& state) { + run_sort_keys_benchmark(state, + segment_count, + segment_length, + target_size, + stream); + })); + } + } +} + +int main(int argc, char* argv[]) +{ + cli::Parser parser(argc, argv); + parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("trials", "trials", -1, "number of iterations"); + parser.set_optional("name_format", + "name_format", + "human", + "either: json,human,txt"); + +#ifdef BENCHMARK_CONFIG_TUNING + // optionally run an evenly split subset of benchmarks, when making multiple program invocations + parser.set_optional("parallel_instance", + "parallel_instance", + 0, + "parallel instance index"); + parser.set_optional("parallel_instances", + "parallel_instances", + 1, + "total parallel instances"); +#endif + + parser.run_and_exit_if_error(); + + // Parse argv + benchmark::Initialize(&argc, argv); + const size_t size = parser.get("size"); + const int trials = parser.get("trials"); + bench_naming::set_format(parser.get("name_format")); + + // HIP + hipStream_t stream = 0; // default + + // Benchmark info + add_common_benchmark_info(); + benchmark::AddCustomContext("size", std::to_string(size)); + + // Add benchmarks + std::vector benchmarks; +#ifdef BENCHMARK_CONFIG_TUNING + const int parallel_instance = parser.get("parallel_instance"); + const int parallel_instances = parser.get("parallel_instances"); + 
config_autotune_register::register_benchmark_subset(benchmarks, + parallel_instance, + parallel_instances, + size, + stream); +#else + add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); +#endif + + // Use manual timing + for(auto& b : benchmarks) + { + b->UseManualTime(); + b->Unit(benchmark::kMillisecond); + } + + // Force number of iterations + if(trials > 0) + { + for(auto& b : benchmarks) + { + b->Iterations(trials); + } + } + + // Run benchmarks + benchmark::RunSpecifiedBenchmarks(); + return 0; +} diff --git a/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.cpp.in b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.cpp.in new file mode 100644 index 000000000..4913fdff7 --- /dev/null +++ b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.cpp.in @@ -0,0 +1,34 @@ +// MIT License +// +// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include + +#include "benchmark_device_segmented_radix_sort_keys.parallel.hpp" +#include "benchmark_utils.hpp" + +namespace +{ +auto benchmarks = config_autotune_register::create_bulk(device_segmented_radix_sort_benchmark_generator<@BlockSize@, + @ItemsPerThread@, + @KeyType@, + @PartitionAllowed@>::create); +} // namespace diff --git a/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp new file mode 100644 index 000000000..4227d223a --- /dev/null +++ b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp @@ -0,0 +1,373 @@ +// MIT License +// +// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_KEYS_PARALLEL_HPP_ +#define ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_KEYS_PARALLEL_HPP_ + +#include +#include +#include + +// Google Benchmark +#include + +// HIP API +#include + +// rocPRIM +#include + +#include "benchmark_utils.hpp" + +template +std::string warp_sort_config_name(T const& warp_sort_config) +{ + return "{pa:" + std::to_string(warp_sort_config.partitioning_allowed) + + ",lwss:" + std::to_string(warp_sort_config.logical_warp_size_small) + + ",ipts:" + std::to_string(warp_sort_config.items_per_thread_small) + + ",bss:" + std::to_string(warp_sort_config.block_size_small) + + ",pt:" + std::to_string(warp_sort_config.partitioning_threshold) + + ",lwsm:" + std::to_string(warp_sort_config.logical_warp_size_medium) + + ",iptm:" + std::to_string(warp_sort_config.items_per_thread_medium) + + ",bsm:" + std::to_string(warp_sort_config.block_size_medium) + "}"; +} + +template +std::string config_name() +{ + const rocprim::detail::segmented_radix_sort_config_params config = Config(); + return "{bs:" + std::to_string(config.kernel_config.block_size) + + ",ipt:" + std::to_string(config.kernel_config.items_per_thread) + + ",lrb:" + std::to_string(config.long_radix_bits) + + ",srb:" + std::to_string(config.short_radix_bits) + + ",eupws:" + std::to_string(config.enable_unpartitioned_warp_sort) + + ",wsc:" + warp_sort_config_name(config.warp_sort_config) + "}"; +} + +template<> +inline std::string config_name() +{ + return 
"default_config"; +} + +template +struct device_segmented_radix_sort_benchmark : public config_autotune_interface +{ + std::string name() const override + { + using namespace std::string_literals; + const rocprim::detail::segmented_radix_sort_config_params config = Config(); + return bench_naming::format_name( + "{lvl:device,algo:segmented_radix_sort,key_type:" + std::string(Traits::name()) + + ",value_type:empty_type" + ",cfg:" + config_name() + "}"); + } + + static constexpr unsigned int batch_size = 10; + static constexpr unsigned int warmup_size = 5; + + void run_benchmark(benchmark::State& state, + size_t num_segments, + size_t mean_segment_length, + size_t target_size, + hipStream_t stream) const + { + using offset_type = int; + using key_type = Key; + + std::vector offsets; + offsets.push_back(0); + + static constexpr int seed = 716; + std::default_random_engine gen(seed); + + std::normal_distribution segment_length_dis( + static_cast(mean_segment_length), + 0.1 * mean_segment_length); + + size_t offset = 0; + for(size_t segment_index = 0; segment_index < num_segments;) + { + const double segment_length_candidate = std::round(segment_length_dis(gen)); + if(segment_length_candidate < 0) + { + continue; + } + const offset_type segment_length = static_cast(segment_length_candidate); + offset += segment_length; + offsets.push_back(offset); + ++segment_index; + } + const size_t size = offset; + const size_t segments_count = offsets.size() - 1; + + std::vector keys_input; + if(std::is_floating_point::value) + { + keys_input = get_random_data(size, + static_cast(-1000), + static_cast(1000)); + } + else + { + keys_input = get_random_data(size, + std::numeric_limits::min(), + std::numeric_limits::max()); + } + size_t batch_size = 1; + if(size < target_size) + { + batch_size = (target_size + size - 1) / size; + } + + offset_type* d_offsets; + HIP_CHECK(hipMalloc(&d_offsets, offsets.size() * sizeof(offset_type))); + HIP_CHECK(hipMemcpy(d_offsets, + offsets.data(), + 
offsets.size() * sizeof(offset_type), + hipMemcpyHostToDevice)); + + key_type* d_keys_input; + key_type* d_keys_output; + HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); + HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); + HIP_CHECK(hipMemcpy(d_keys_input, + keys_input.data(), + size * sizeof(key_type), + hipMemcpyHostToDevice)); + + void* d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; + HIP_CHECK(rocprim::segmented_radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + HIP_CHECK(hipDeviceSynchronize()); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + HIP_CHECK(rocprim::segmented_radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + } + HIP_CHECK(hipDeviceSynchronize()); + + // HIP events creation + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for(auto _ : state) + { + // Record start event + HIP_CHECK(hipEventRecord(start, stream)); + + for(size_t i = 0; i < batch_size; i++) + { + HIP_CHECK(rocprim::segmented_radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + } + + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_mseconds; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + + // Destroy HIP events + HIP_CHECK(hipEventDestroy(start)); + 
HIP_CHECK(hipEventDestroy(stop)); + + state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_offsets)); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_keys_output)); + } + + void run(benchmark::State& state, size_t size, hipStream_t stream) const override + { + constexpr std::array + segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; + constexpr std::array segment_lengths{30, 256, 3000, 300000}; + + for(const auto segment_count : segment_counts) + { + for(const auto segment_length : segment_lengths) + { + const auto number_of_elements = segment_count * segment_length; + if(number_of_elements > 33554432 || number_of_elements < 300000) + { + continue; + } + + run_benchmark(state, segment_count, segment_length, size, stream); + } + } + } +}; + +template class T, bool enable, Tp... Idx> +struct decider; +template +struct device_segmented_radix_sort_benchmark_generator +{ + template + struct create_lrb + { + template + struct create_srb + { + template + struct create_euws + { + template + struct create_lwss + { + template + struct create_pt + { + void operator()( + std::vector>& storage) + { + storage.emplace_back( + std::make_unique>>>()); + } + }; + + void + operator()(std::vector>& storage) + { + static_for_each, create_pt>(storage); + } + }; + + void operator()(std::vector>& storage) + { + if(PartitionAllowed) + { + + static_for_each, + create_lwss>(storage); + } + else + { + storage.emplace_back(std::make_unique>>()); + } + } + }; + + void operator()(std::vector>& storage) + { + decider::do_the_thing( + storage); + } + }; + + void operator()(std::vector>& storage) + { + decider::do_the_thing( + storage); + } + }; + + static void create(std::vector>& storage) + { + static_for_each, create_lrb>(storage); + } +}; + +template class T, Tp... 
Idx> +struct decider +{ + inline static void + do_the_thing(std::vector>& storage) + { + static_for_each, T>(storage); + } +}; + +template class T, Tp... Idx> +struct decider +{ + inline static void + do_the_thing(std::vector>& /*storage*/) + {} +}; + +#endif // ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_KEYS_PARALLEL_HPP_ diff --git a/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp b/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp new file mode 100644 index 000000000..4aba2b436 --- /dev/null +++ b/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp @@ -0,0 +1,357 @@ +// MIT License +// +// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#include +#include +#include +#include +#include + +// Google Benchmark +#include "benchmark/benchmark.h" +// CmdParser +#include "benchmark_utils.hpp" +#include "cmdparser.hpp" + +// HIP API +#include + +// rocPRIM +#include + +#ifndef DEFAULT_N +const size_t DEFAULT_N = 1024 * 1024 * 32; +#endif + +namespace rp = rocprim; + +namespace +{ + +constexpr unsigned int warmup_size = 2; +constexpr size_t min_size = 30000; +constexpr std::array segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; +constexpr std::array segment_lengths{30, 256, 3000, 300000}; +} // namespace + +// This benchmark only handles the rocprim::segmented_radix_sort_pairs function. The benchmark was separated into two (keys and pairs) +// because the combined binary became too large to link; linking runs into a "relocation R_X86_64_PC32 out of range" error. +// This happens partly because the algorithm has 4 kernels and decides at runtime which one to call. + +template +void run_sort_pairs_benchmark(benchmark::State& state, + size_t num_segments, + size_t mean_segment_length, + size_t target_size, + hipStream_t stream) +{ + using offset_type = int; + using key_type = Key; + using value_type = Value; + + // Generate data + std::vector offsets; + offsets.push_back(0); + + static constexpr int seed = 716; + std::default_random_engine gen(seed); + + std::normal_distribution segment_length_dis(static_cast(mean_segment_length), + 0.1 * mean_segment_length); + + size_t offset = 0; + for(size_t segment_index = 0; segment_index < num_segments;) + { + const double segment_length_candidate = std::round(segment_length_dis(gen)); + if(segment_length_candidate < 0) + { + continue; + } + const offset_type segment_length = static_cast(segment_length_candidate); + offset += segment_length; + offsets.push_back(offset); + ++segment_index; + } + const size_t size = offset; + const size_t segments_count = offsets.size() - 1; + + std::vector keys_input; + if(std::is_floating_point::value) + { + keys_input = 
get_random_data(size, + static_cast(-1000), + static_cast(1000)); + } + else + { + keys_input = get_random_data(size, + std::numeric_limits::min(), + std::numeric_limits::max()); + } + size_t batch_size = 1; + if(size < target_size) + { + batch_size = (target_size + size - 1) / size; + } + + std::vector values_input(size); + std::iota(values_input.begin(), values_input.end(), 0); + + offset_type* d_offsets; + HIP_CHECK(hipMalloc(&d_offsets, (segments_count + 1) * sizeof(offset_type))); + HIP_CHECK(hipMemcpy(d_offsets, + offsets.data(), + (segments_count + 1) * sizeof(offset_type), + hipMemcpyHostToDevice)); + + key_type* d_keys_input; + key_type* d_keys_output; + HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); + HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); + HIP_CHECK( + hipMemcpy(d_keys_input, keys_input.data(), size * sizeof(key_type), hipMemcpyHostToDevice)); + + value_type* d_values_input; + value_type* d_values_output; + HIP_CHECK(hipMalloc(&d_values_input, size * sizeof(value_type))); + HIP_CHECK(hipMalloc(&d_values_output, size * sizeof(value_type))); + HIP_CHECK(hipMemcpy(d_values_input, + values_input.data(), + size * sizeof(value_type), + hipMemcpyHostToDevice)); + + void* d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; + HIP_CHECK(rp::segmented_radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + HIP_CHECK(hipDeviceSynchronize()); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + HIP_CHECK(rp::segmented_radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + 
false)); + } + HIP_CHECK(hipDeviceSynchronize()); + + // HIP events creation + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for(auto _ : state) + { + // Record start event + HIP_CHECK(hipEventRecord(start, stream)); + + for(size_t i = 0; i < batch_size; i++) + { + HIP_CHECK(rp::segmented_radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + } + + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_mseconds; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + + // Destroy HIP events + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + state.SetBytesProcessed(state.iterations() * batch_size * size + * (sizeof(key_type) + sizeof(value_type))); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_offsets)); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_keys_output)); + HIP_CHECK(hipFree(d_values_input)); + HIP_CHECK(hipFree(d_values_output)); +} + +template +void add_sort_pairs_benchmarks(std::vector& benchmarks, + hipStream_t stream, + size_t max_size, + size_t min_size, + size_t target_size) +{ + std::string key_name = Traits::name(); + std::string value_name = Traits::name(); + for(const auto segment_count : segment_counts) + { + for(const auto segment_length : segment_lengths) + { + const auto number_of_elements = segment_count * segment_length; + if(number_of_elements > max_size || number_of_elements < min_size) + { + continue; + } + benchmarks.push_back(benchmark::RegisterBenchmark( + bench_naming::format_name( + 
"{lvl:device,algo:radix_sort_segmented,key_type:" + key_name + ",value_type:" + + value_name + ",segment_count:" + std::to_string(segment_count) + + ",segment_length:" + std::to_string(segment_length) + ",cfg:default_config}") + .c_str(), + [=](benchmark::State& state) + { + run_sort_pairs_benchmark(state, + segment_count, + segment_length, + target_size, + stream); + })); + } + } +} + +int main(int argc, char* argv[]) +{ + cli::Parser parser(argc, argv); + parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("trials", "trials", -1, "number of iterations"); + parser.set_optional("name_format", + "name_format", + "human", + "either: json,human,txt"); + +#ifdef BENCHMARK_CONFIG_TUNING + // optionally run an evenly split subset of benchmarks, when making multiple program invocations + parser.set_optional("parallel_instance", + "parallel_instance", + 0, + "parallel instance index"); + parser.set_optional("parallel_instances", + "parallel_instances", + 1, + "total parallel instances"); +#endif + + parser.run_and_exit_if_error(); + + // Parse argv + benchmark::Initialize(&argc, argv); + const size_t size = parser.get("size"); + const int trials = parser.get("trials"); + bench_naming::set_format(parser.get("name_format")); + + // HIP + hipStream_t stream = 0; // default + + // Benchmark info + add_common_benchmark_info(); + benchmark::AddCustomContext("size", std::to_string(size)); + + // Add benchmarks + std::vector benchmarks; +#ifdef BENCHMARK_CONFIG_TUNING + const int parallel_instance = parser.get("parallel_instance"); + const int parallel_instances = parser.get("parallel_instances"); + config_autotune_register::register_benchmark_subset(benchmarks, + parallel_instance, + parallel_instances, + size, + stream); +#else + using custom_float2 = custom_type; + using custom_double2 = custom_type; + add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, 
size / 2); + add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_pairs_benchmarks(benchmarks, + stream, + size, + min_size, + size / 2); + add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_pairs_benchmarks(benchmarks, + stream, + size, + min_size, + size / 2); +#endif + + // Use manual timing + for(auto& b : benchmarks) + { + b->UseManualTime(); + b->Unit(benchmark::kMillisecond); + } + + // Force number of iterations + if(trials > 0) + { + for(auto& b : benchmarks) + { + b->Iterations(trials); + } + } + + // Run benchmarks + benchmark::RunSpecifiedBenchmarks(); + return 0; +} diff --git a/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.cpp.in b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.cpp.in new file mode 100644 index 000000000..55fc0a849 --- /dev/null +++ b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.cpp.in @@ -0,0 +1,35 @@ +// MIT License +// +// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include + +#include "benchmark_device_segmented_radix_sort_pairs.parallel.hpp" +#include "benchmark_utils.hpp" + +namespace +{ +auto benchmarks = config_autotune_register::create_bulk(device_segmented_radix_sort_benchmark_generator<@BlockSize@, + @ItemsPerThread@, + @KeyType@, + @ValueType@, + @PartitionAllowed@>::create); +} // namespace diff --git a/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp new file mode 100644 index 000000000..917af3a25 --- /dev/null +++ b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp @@ -0,0 +1,412 @@ +// MIT License +// +// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_PAIRS_PARALLEL_HPP_ +#define ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_PAIRS_PARALLEL_HPP_ + +#include +#include +#include + +// Google Benchmark +#include + +// HIP API +#include + +// rocPRIM +#include + +#include "benchmark_utils.hpp" + +template +std::string warp_sort_config_name(T const& warp_sort_config) +{ + return "{pa:" + std::to_string(warp_sort_config.partitioning_allowed) + + ",lwss:" + std::to_string(warp_sort_config.logical_warp_size_small) + + ",ipts:" + std::to_string(warp_sort_config.items_per_thread_small) + + ",bss:" + std::to_string(warp_sort_config.block_size_small) + + ",pt:" + std::to_string(warp_sort_config.partitioning_threshold) + + ",lwsm:" + std::to_string(warp_sort_config.logical_warp_size_medium) + + ",iptm:" + std::to_string(warp_sort_config.items_per_thread_medium) + + ",bsm:" + std::to_string(warp_sort_config.block_size_medium) + "}"; +} + +template +std::string config_name() +{ + const rocprim::detail::segmented_radix_sort_config_params config = Config(); + return "{bs:" + std::to_string(config.kernel_config.block_size) + + ",ipt:" + std::to_string(config.kernel_config.items_per_thread) + + ",lrb:" + std::to_string(config.long_radix_bits) + + ",srb:" + std::to_string(config.short_radix_bits) + + ",eupws:" + std::to_string(config.enable_unpartitioned_warp_sort) + + ",wsc:" + warp_sort_config_name(config.warp_sort_config) + "}"; +} + +template<> +inline std::string config_name() +{ + return "default_config"; +} + +template +struct device_segmented_radix_sort_benchmark : public config_autotune_interface +{ + std::string name() const override + { + using namespace std::string_literals; + const 
rocprim::detail::segmented_radix_sort_config_params config = Config(); + return bench_naming::format_name("{lvl:device,algo:segmented_radix_sort,key_type:" + + std::string(Traits::name()) + + ",value_type:" + std::string(Traits::name()) + + ",cfg:" + config_name() + "}"); + } + + static constexpr unsigned int batch_size = 10; + static constexpr unsigned int warmup_size = 5; + + void run_benchmark(benchmark::State& state, + size_t num_segments, + size_t mean_segment_length, + size_t target_size, + hipStream_t stream) const + { + using offset_type = int; + using key_type = Key; + using value_type = Value; + + std::vector offsets; + offsets.push_back(0); + + static constexpr int seed = 716; + std::default_random_engine gen(seed); + + std::normal_distribution segment_length_dis( + static_cast(mean_segment_length), + 0.1 * mean_segment_length); + + size_t offset = 0; + for(size_t segment_index = 0; segment_index < num_segments;) + { + const double segment_length_candidate = std::round(segment_length_dis(gen)); + if(segment_length_candidate < 0) + { + continue; + } + const offset_type segment_length = static_cast(segment_length_candidate); + offset += segment_length; + offsets.push_back(offset); + ++segment_index; + } + const size_t size = offset; + const size_t segments_count = offsets.size() - 1; + + std::vector keys_input; + if(std::is_floating_point::value) + { + keys_input = get_random_data(size, + static_cast(-1000), + static_cast(1000)); + } + else + { + keys_input = get_random_data(size, + std::numeric_limits::min(), + std::numeric_limits::max()); + } + + std::vector values_input; + if(std::is_floating_point::value) + { + values_input = get_random_data(size, + static_cast(-1000), + static_cast(1000)); + } + else + { + values_input = get_random_data(size, + std::numeric_limits::min(), + std::numeric_limits::max()); + } + + size_t batch_size = 1; + if(size < target_size) + { + batch_size = (target_size + size - 1) / size; + } + + offset_type* d_offsets; + 
HIP_CHECK(hipMalloc(&d_offsets, offsets.size() * sizeof(offset_type))); + HIP_CHECK(hipMemcpy(d_offsets, + offsets.data(), + offsets.size() * sizeof(offset_type), + hipMemcpyHostToDevice)); + + key_type* d_keys_input; + key_type* d_keys_output; + HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); + HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); + HIP_CHECK(hipMemcpy(d_keys_input, + keys_input.data(), + size * sizeof(key_type), + hipMemcpyHostToDevice)); + + value_type* d_values_input; + value_type* d_values_output; + HIP_CHECK(hipMalloc(&d_values_input, size * sizeof(value_type))); + HIP_CHECK(hipMalloc(&d_values_output, size * sizeof(value_type))); + HIP_CHECK(hipMemcpy(d_values_input, + values_input.data(), + size * sizeof(value_type), + hipMemcpyHostToDevice)); + + void* d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; + HIP_CHECK(rocprim::segmented_radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + HIP_CHECK(hipDeviceSynchronize()); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + HIP_CHECK(rocprim::segmented_radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + } + HIP_CHECK(hipDeviceSynchronize()); + + // HIP events creation + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for(auto _ : state) + { + // Record start event + HIP_CHECK(hipEventRecord(start, stream)); + + for(size_t i = 0; i < batch_size; i++) + { + HIP_CHECK(rocprim::segmented_radix_sort_pairs(d_temporary_storage, + 
temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + } + + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_mseconds; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + + // Destroy HIP events + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_offsets)); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_keys_output)); + HIP_CHECK(hipFree(d_values_input)); + HIP_CHECK(hipFree(d_values_output)); + } + + void run(benchmark::State& state, size_t size, hipStream_t stream) const override + { + constexpr std::array + segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; + constexpr std::array segment_lengths{30, 256, 3000, 300000}; + + for(const auto segment_count : segment_counts) + { + for(const auto segment_length : segment_lengths) + { + const auto number_of_elements = segment_count * segment_length; + if(number_of_elements > 33554432 || number_of_elements < 300000) + { + continue; + } + + run_benchmark(state, segment_count, segment_length, size, stream); + } + } + } +}; + +template class T, bool enable, Tp... 
Idx> +struct decider; +template +struct device_segmented_radix_sort_benchmark_generator +{ + template + struct create_lrb + { + template + struct create_srb + { + template + struct create_euws + { + template + struct create_lwss + { + template + struct create_pt + { + void operator()( + std::vector>& storage) + { + storage.emplace_back( + std::make_unique>>>()); + } + }; + + void + operator()(std::vector>& storage) + { + static_for_each, create_pt>(storage); + } + }; + + void operator()(std::vector>& storage) + { + if(PartitionAllowed) + { + + static_for_each, + create_lwss>(storage); + } + else + { + storage.emplace_back(std::make_unique>>()); + } + } + }; + + void operator()(std::vector>& storage) + { + static_for_each, create_euws>(storage); + } + }; + + void operator()(std::vector>& storage) + { + decider::do_the_thing( + storage); + } + }; + + static void create(std::vector>& storage) + { + static_for_each, create_lrb>(storage); + } +}; + +template class T, Tp... Idx> +struct decider +{ + inline static void + do_the_thing(std::vector>& storage) + { + static_for_each, T>(storage); + } +}; + +template class T, Tp... Idx> +struct decider +{ + inline static void + do_the_thing(std::vector>& /*storage*/) + {} +}; + +#endif // ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_PAIRS_PARALLEL_HPP_ diff --git a/benchmark/benchmark_utils.hpp b/benchmark/benchmark_utils.hpp index 2d67ff3fa..e5863e758 100644 --- a/benchmark/benchmark_utils.hpp +++ b/benchmark/benchmark_utils.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -306,12 +306,8 @@ inline bool is_warp_size_supported(const unsigned int required_warp_size, const } template -struct DeviceSelectWarpSize -{ - static constexpr unsigned int value = ::rocprim::device_warp_size() >= LogicalWarpSize - ? LogicalWarpSize - : ::rocprim::device_warp_size(); -}; +__device__ constexpr bool device_test_enabled_for_warp_size_v + = ::rocprim::device_warp_size() >= LogicalWarpSize; template std::vector diff --git a/benchmark/benchmark_warp_exchange.cpp b/benchmark/benchmark_warp_exchange.cpp index 997c2e357..64a0c65a5 100644 --- a/benchmark/benchmark_warp_exchange.cpp +++ b/benchmark/benchmark_warp_exchange.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -124,16 +124,14 @@ struct ScatterToStripedOp } }; -template< - class T, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int LogicalWarpSize, - class Op -> -__global__ -__launch_bounds__(BlockSize) -auto warp_exchange_kernel(T* d_output, unsigned int trials) -> typename std::enable_if::value, void>::type +template +__device__ auto warp_exchange_benchmark(T* d_output, unsigned int trials) + -> std::enable_if_t + && !std::is_same::value> { T thread_data[ItemsPerThread]; @@ -141,16 +139,12 @@ auto warp_exchange_kernel(T* d_output, unsigned int trials) -> typename std::ena for(unsigned int i = 0; i < ItemsPerThread; i++) { // generate unique value each data-element - thread_data[i] = static_cast(hipThreadIdx_x*ItemsPerThread+i); + thread_data[i] = static_cast(threadIdx.x * ItemsPerThread + i); } - using 
warp_exchange_type = ::rocprim::warp_exchange< - T, - ItemsPerThread, - DeviceSelectWarpSize::value - >; + using warp_exchange_type = ::rocprim::warp_exchange; constexpr unsigned int warps_in_block = BlockSize / LogicalWarpSize; - const unsigned int warp_id = hipThreadIdx_x / LogicalWarpSize; + const unsigned int warp_id = threadIdx.x / LogicalWarpSize; ROCPRIM_SHARED_MEMORY typename warp_exchange_type::storage_type storage[warps_in_block]; ROCPRIM_NO_UNROLL @@ -163,44 +157,37 @@ auto warp_exchange_kernel(T* d_output, unsigned int trials) -> typename std::ena ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) { - const unsigned int global_idx = - (BlockSize * hipBlockIdx_x + hipThreadIdx_x) * ItemsPerThread + i; + const unsigned int global_idx = (BlockSize * blockIdx.x + threadIdx.x) * ItemsPerThread + i; d_output[global_idx] = thread_data[i]; } } -template< - class T, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int LogicalWarpSize, - class Op - > -__global__ -__launch_bounds__(BlockSize) - auto warp_exchange_kernel(T* d_output, unsigned int trials) -> typename std::enable_if::value, void>::type +template +__device__ auto warp_exchange_benchmark(T* d_output, unsigned int trials) + -> std::enable_if_t + && std::is_same::value> { T thread_data[ItemsPerThread]; unsigned int thread_ranks[ItemsPerThread]; constexpr unsigned int warps_in_block = BlockSize / LogicalWarpSize; - const unsigned int warp_id = hipThreadIdx_x / LogicalWarpSize; - const unsigned int lane_id = hipThreadIdx_x % LogicalWarpSize; + const unsigned int warp_id = threadIdx.x / LogicalWarpSize; + const unsigned int lane_id = threadIdx.x % LogicalWarpSize; ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) { // generate unique value each data-element - thread_data[i] = static_cast(hipThreadIdx_x*ItemsPerThread+i); + thread_data[i] = static_cast(threadIdx.x * ItemsPerThread + i); // generate unique destination location for each data-element const unsigned 
int s_lane_id = i % 2 == 0 ? LogicalWarpSize - 1 - lane_id : lane_id; thread_ranks[i] = s_lane_id*ItemsPerThread+i; // scatter values in warp across whole storage } - using warp_exchange_type = ::rocprim::warp_exchange< - T, - ItemsPerThread, - DeviceSelectWarpSize::value - >; + using warp_exchange_type = ::rocprim::warp_exchange; ROCPRIM_SHARED_MEMORY typename warp_exchange_type::storage_type storage[warps_in_block]; ROCPRIM_NO_UNROLL @@ -213,12 +200,30 @@ __launch_bounds__(BlockSize) ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) { - const unsigned int global_idx = - (BlockSize * hipBlockIdx_x + hipThreadIdx_x) * ItemsPerThread + i; + const unsigned int global_idx = (BlockSize * blockIdx.x + threadIdx.x) * ItemsPerThread + i; d_output[global_idx] = thread_data[i]; } } +template +__device__ auto warp_exchange_benchmark(T* /*d_output*/, unsigned int /*trials*/) + -> std::enable_if_t> +{} + +template +__global__ __launch_bounds__(BlockSize) void warp_exchange_kernel(T* d_output, unsigned int trials) +{ + warp_exchange_benchmark(d_output, trials); +} + template< class T, unsigned int BlockSize, @@ -245,18 +250,8 @@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t N) // Record start event HIP_CHECK(hipEventRecord(start, stream)); - hipLaunchKernelGGL( - HIP_KERNEL_NAME(warp_exchange_kernel< - T, - BlockSize, - ItemsPerThread, - LogicalWarpSize, - Op - > - ), - dim3(size / items_per_block), dim3(BlockSize), 0, stream, - d_output, trials - ); + warp_exchange_kernel + <<>>(d_output, trials); HIP_CHECK(hipPeekAtLastError()); diff --git a/benchmark/benchmark_warp_scan.cpp b/benchmark/benchmark_warp_scan.cpp index e9ba06674..daee015ae 100644 --- a/benchmark/benchmark_warp_scan.cpp +++ b/benchmark/benchmark_warp_scan.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -179,9 +179,8 @@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t size) template void add_benchmarks(std::vector& benchmarks, - const std::string& method_name, - hipStream_t stream, - size_t size) + hipStream_t stream, + size_t size) { using custom_double2 = custom_type; using custom_int_double = custom_type; @@ -226,8 +225,8 @@ int main(int argc, char *argv[]) // Add benchmarks std::vector benchmarks; - add_benchmarks(benchmarks, "inclusive", stream, size); - add_benchmarks(benchmarks, "exclusive", stream, size); + add_benchmarks(benchmarks, stream, size); //inclusive + add_benchmarks(benchmarks, stream, size); //exclusive // Use manual timing for(auto& b : benchmarks) diff --git a/docs/conf.py b/docs/conf.py index d70124f3c..58c7445ed 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,3 +1,23 @@ +# Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. For a full @@ -6,17 +26,14 @@ import re -from rocm_docs import ROCmDocs - with open('../CMakeLists.txt', encoding='utf-8') as f: match = re.search(r'.*\bset\(VERSION_STRING\s+\"?([0-9.]+)[^0-9.]+', f.read()) if not match: raise ValueError("VERSION not found!") version_number = match[1] -left_nav_title = f"rocPRIM {version_number} Documentation" -# for PDF output on Read the Docs project = "rocPRIM Documentation" +html_title = f"rocPRIM {version_number} Documentation" author = "Advanced Micro Devices, Inc." copyright = "Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved." version = version_number @@ -24,14 +41,18 @@ external_toc_path = "./sphinx/_toc.yml" -docs_core = ROCmDocs(left_nav_title) -docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/xml") -docs_core.setup() +extensions = ["rocm_docs", "rocm_docs.doxygen"] -external_projects_current_project = "rocprim" +html_theme = "rocm_docs_theme" -for sphinx_var in ROCmDocs.SPHINX_VARS: - globals()[sphinx_var] = getattr(docs_core, sphinx_var) +doxygen_root = "doxygen" +doxygen_project = { + "name": "rocPRIM", + "path": "doxygen/xml", +} + +external_projects = [] +external_projects_current_project = "rocprim" cpp_id_attributes = ["__global__", "__device__", "__host__", "__forceinline__", "static"] cpp_paren_attributes = ["__declspec"] diff --git a/docs/device_ops/adjacent_difference.rst b/docs/device_ops/adjacent_difference.rst index 987614cd4..670d5783f 100644 --- a/docs/device_ops/adjacent_difference.rst +++ b/docs/device_ops/adjacent_difference.rst @@ -23,6 +23,11 @@ left, inplace .. 
doxygenfunction:: rocprim::adjacent_difference_inplace(void *const temporary_storage, std::size_t &storage_size, const InputIt values, const std::size_t size, const BinaryFunction op=BinaryFunction {}, const hipStream_t stream=0, const bool debug_synchronous=false) +left, aliased +~~~~~~~~~~~~~ + +.. doxygenfunction:: rocprim::adjacent_difference_inplace(void *const temporary_storage, std::size_t &storage_size, const InputIt input, const OutputIt output, const std::size_t size, const BinaryFunction op=BinaryFunction {}, const hipStream_t stream=0, const bool debug_synchronous=false) + right ============= @@ -33,3 +38,8 @@ right, inplace .. doxygenfunction:: rocprim::adjacent_difference_right_inplace(void *const temporary_storage, std::size_t &storage_size, const InputIt values, const std::size_t size, const BinaryFunction op=BinaryFunction {}, const hipStream_t stream=0, const bool debug_synchronous=false) +right, aliased +~~~~~~~~~~~~~~ + +.. doxygenfunction:: rocprim::adjacent_difference_right_inplace(void *const temporary_storage, std::size_t &storage_size, const InputIt input, const OutputIt output, const std::size_t size, const BinaryFunction op=BinaryFunction {}, const hipStream_t stream=0, const bool debug_synchronous=false) + diff --git a/docs/device_ops/index.rst b/docs/device_ops/index.rst index b75bb2184..1f01a7414 100644 --- a/docs/device_ops/index.rst +++ b/docs/device_ops/index.rst @@ -21,3 +21,4 @@ * :ref:`dev-adjacent_difference` * :ref:`dev-binary_search` * :ref:`dev-histogram` + * :ref:`dev-memcpy` diff --git a/docs/device_ops/memcpy.rst b/docs/device_ops/memcpy.rst new file mode 100644 index 000000000..8330daed5 --- /dev/null +++ b/docs/device_ops/memcpy.rst @@ -0,0 +1,19 @@ +.. meta:: + :description: rocPRIM documentation and API reference library + :keywords: rocPRIM, ROCm, API, documentation + +.. _dev-memcpy: + + +Memcpy +------ + +Configuring the kernel +~~~~~~~~~~~~~~~~~~~~~~ + +.. 
doxygenstruct:: rocprim::batch_memcpy_config + +batch_memcpy +~~~~~~~~~~~~ + +.. doxygenfunction:: rocprim::batch_memcpy(void* temporary_storage, size_t& storage_size, InputBufferItType sources, OutputBufferItType destinations, BufferSizeItType sizes, uint32_t num_copies, hipStream_t stream = hipStreamDefault, bool debug_synchronous = false) diff --git a/docs/device_ops/partition.rst b/docs/device_ops/partition.rst index a10d95920..ecd5636a9 100644 --- a/docs/device_ops/partition.rst +++ b/docs/device_ops/partition.rst @@ -13,6 +13,11 @@ partition .. doxygenfunction:: rocprim::partition(void *temporary_storage, size_t &storage_size, InputIterator input, OutputIterator output, SelectedCountOutputIterator selected_count_output, const size_t size, UnaryPredicate predicate, const hipStream_t stream=0, const bool debug_synchronous=false) +partition_two_way +~~~~~~~~~~~~~~~~~ + +.. doxygenfunction:: rocprim::partition_two_way(void* temporary_storage, size_t& storage_size, InputIterator input, SelectedOutputIterator output_selected, RejectedOutputIterator output_rejected, SelectedCountOutputIterator selected_count_output, const size_t size, Predicate predicate, const hipStream_t stream = 0, const bool debug_synchronous = false) + partition_three_way ====================== diff --git a/docs/reference/intrinsics.rst b/docs/reference/intrinsics.rst index 0aac48f55..5f09f174f 100644 --- a/docs/reference/intrinsics.rst +++ b/docs/reference/intrinsics.rst @@ -52,4 +52,6 @@ Active threads ================== .. doxygenfunction:: rocprim::ballot (int predicate) +.. doxygenfunction:: rocprim::group_elect(lane_mask_type mask) .. doxygenfunction:: rocprim::masked_bit_count (lane_mask_type x, unsigned int add=0) +.. 
doxygenfunction:: rocprim::match_any(unsigned int label, bool valid = true) diff --git a/docs/reference/ops_summary.rst b/docs/reference/ops_summary.rst index 01d7238fc..95cfa3f93 100644 --- a/docs/reference/ops_summary.rst +++ b/docs/reference/ops_summary.rst @@ -44,8 +44,9 @@ Partition/Merge Data Movement =============== -* ``store`` stores the sequence to a continuous memory zone. There are variations to use an optimized path or to specify how to store the sequence to better fit the access patterns of the CUs -* ``load`` the complementary operations of the above ones +* ``store`` stores the sequence to a continuous memory zone. There are variations to use an optimized path or to specify how to store the sequence to better fit the access patterns of the CUs. +* ``load`` the complementary operations of the above ones. +* ``memcpy`` copies bytes between device sources and destinations Other operations ====================== diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index f0236fd48..07afd927c 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -23,13 +23,14 @@ subtrees: - file: device_ops/sort.rst - file: device_ops/merge.rst - file: device_ops/partition.rst - - file: device_ops/run_lenght_encoding.rst + - file: device_ops/run_length_encoding.rst - file: device_ops/scan.rst - file: device_ops/select.rst - file: device_ops/reduce.rst - file: device_ops/adjacent_difference.rst - file: device_ops/binary_search.rst - file: device_ops/histogram.rst + - file: device_ops/memcpy.rst - file: block_ops/index.rst subtrees: - entries: @@ -60,6 +61,5 @@ subtrees: - file: reference/thread_ops.rst - file: reference/iterators.rst - file: reference/intrinsics.rst - - file: reference/reorder.rst - file: reference/acknowledge.rst - file: license.rst diff --git a/rocprim/CMakeLists.txt b/rocprim/CMakeLists.txt index 2b02c7f2f..037997d7d 100644 --- a/rocprim/CMakeLists.txt +++ b/rocprim/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright 
(c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,7 +29,7 @@ configure_file( @ONLY ) -if (BUILD_FILE_REORG_BACKWARD_COMPATIBILITY AND NOT WIN32) +if (ROCPRIM_INSTALL AND BUILD_FILE_REORG_BACKWARD_COMPATIBILITY AND NOT WIN32) rocm_wrap_header_file( "rocprim_version.hpp" WRAPPER_LOCATIONS rocprim/include/rocprim @@ -54,35 +54,36 @@ target_link_libraries(rocprim_hip INTERFACE rocprim hip::device) # Installation +if (ROCPRIM_INSTALL) + # We need to install headers manually as rocm_install_targets + # does not support header-only libraries (INTERFACE targets) + rocm_install_targets( + TARGETS rocprim rocprim_hip + ) -# We need to install headers manually as rocm_install_targets -# does not support header-only libraries (INTERFACE targets) -rocm_install_targets( - TARGETS rocprim rocprim_hip -) + rocm_install( + DIRECTORY + "include/" + "${PROJECT_BINARY_DIR}/rocprim/include/" + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ + FILES_MATCHING + PATTERN "*.h" + PATTERN "*.hpp" + PERMISSIONS OWNER_WRITE OWNER_READ GROUP_READ WORLD_READ + ) -rocm_install( - DIRECTORY - "include/" - "${PROJECT_BINARY_DIR}/rocprim/include/" - DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ - FILES_MATCHING - PATTERN "*.h" - PATTERN "*.hpp" - PERMISSIONS OWNER_WRITE OWNER_READ GROUP_READ WORLD_READ -) + if (BUILD_FILE_REORG_BACKWARD_COMPATIBILITY AND NOT WIN32) + rocm_install( + DIRECTORY + "${PROJECT_BINARY_DIR}/rocprim/wrapper/" + DESTINATION rocprim/ ) + endif() -if (BUILD_FILE_REORG_BACKWARD_COMPATIBILITY AND NOT WIN32) -rocm_install( - DIRECTORY - "${PROJECT_BINARY_DIR}/rocprim/wrapper/" - DESTINATION rocprim/ ) + # Export targets + rocm_export_targets( + TARGETS roc::rocprim roc::rocprim_hip + DEPENDS PACKAGE hip + NAMESPACE roc:: + ) endif() - -# 
Export targets -rocm_export_targets( - TARGETS roc::rocprim roc::rocprim_hip - DEPENDS PACKAGE hip - NAMESPACE roc:: -) diff --git a/rocprim/include/rocprim/detail/match_result_type.hpp b/rocprim/include/rocprim/detail/match_result_type.hpp deleted file mode 100644 index 3add1f1a2..000000000 --- a/rocprim/include/rocprim/detail/match_result_type.hpp +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) 2018-2021 Advanced Micro Devices, Inc. All rights reserved. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -#ifndef ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_ -#define ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_ - -#include - -#include "../config.hpp" - -BEGIN_ROCPRIM_NAMESPACE -namespace detail -{ - -// invoke_result is based on std::invoke_result. -// The main difference is using ROCPRIM_HOST_DEVICE, this allows to -// use invoke_result with device-only lambdas/functors in host-only functions -// on HIP-clang. 
- -template -struct is_reference_wrapper : std::false_type {}; -template -struct is_reference_wrapper> : std::true_type {}; - -template -struct invoke_impl { - template - ROCPRIM_HOST_DEVICE - static auto call(F&& f, Args&&... args) - -> decltype(std::forward(f)(std::forward(args)...)); -}; - -template -struct invoke_impl -{ - template::type, - class = typename std::enable_if::value>::type - > - ROCPRIM_HOST_DEVICE - static auto get(T&& t) -> T&&; - - template::type, - class = typename std::enable_if::value>::type - > - ROCPRIM_HOST_DEVICE - static auto get(T&& t) -> decltype(t.get()); - - template::type, - class = typename std::enable_if::value>::type, - class = typename std::enable_if::value>::type - > - ROCPRIM_HOST_DEVICE - static auto get(T&& t) -> decltype(*std::forward(t)); - - template::value>::type - > - ROCPRIM_HOST_DEVICE - static auto call(MT1 B::*pmf, T&& t, Args&&... args) - -> decltype((invoke_impl::get(std::forward(t)).*pmf)(std::forward(args)...)); - - template - ROCPRIM_HOST_DEVICE - static auto call(MT B::*pmd, T&& t) - -> decltype(invoke_impl::get(std::forward(t)).*pmd); -}; - -template::type> -ROCPRIM_HOST_DEVICE -auto INVOKE(F&& f, Args&&... 
args) - -> decltype(invoke_impl::call(std::forward(f), std::forward(args)...)); - -// Conforming C++14 implementation (is also a valid C++11 implementation): -template -struct invoke_result_impl { }; -template -struct invoke_result_impl(), std::declval()...))), F, Args...> -{ - using type = decltype(INVOKE(std::declval(), std::declval()...)); -}; - -template -struct invoke_result : invoke_result_impl {}; - -template -struct match_result_type -{ - using type = typename invoke_result::type; -}; - -} // end namespace detail -END_ROCPRIM_NAMESPACE - -#endif // ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_ diff --git a/rocprim/include/rocprim/detail/temp_storage.hpp b/rocprim/include/rocprim/detail/temp_storage.hpp index b31853165..f9e2a57d6 100644 --- a/rocprim/include/rocprim/detail/temp_storage.hpp +++ b/rocprim/include/rocprim/detail/temp_storage.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -65,7 +65,7 @@ struct simple_partition layout storage_layout; /// Compute the required layout for this type and return it. - layout get_layout() + layout get_layout() const { return this->storage_layout; } @@ -129,13 +129,13 @@ struct linear_partition linear_partition(Ts... sub_partitions) : sub_partitions{sub_partitions...} {} /// Compute the required layout for this type and return it. - layout get_layout() + layout get_layout() const { size_t required_alignment = 1; size_t required_size = 0; for_each_in_tuple(this->sub_partitions, - [&](auto& sub_partition) + [&](const auto& sub_partition) { const auto sub_layout = sub_partition.get_layout(); @@ -197,13 +197,13 @@ struct union_partition union_partition(Ts... 
sub_partitions) : sub_partitions{sub_partitions...} {} /// Compute the required layout for this type and return it. - layout get_layout() + layout get_layout() const { size_t required_alignment = 1; size_t required_size = 0; for_each_in_tuple(this->sub_partitions, - [&](auto& sub_partition) + [&](const auto& sub_partition) { const auto sub_layout = sub_partition.get_layout(); diff --git a/rocprim/include/rocprim/detail/various.hpp b/rocprim/include/rocprim/detail/various.hpp index 48203e672..1def58e3d 100644 --- a/rocprim/include/rocprim/detail/various.hpp +++ b/rocprim/include/rocprim/detail/various.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -329,19 +329,23 @@ constexpr std::add_const_t* as_const_ptr(T* ptr) return ptr; } -template -ROCPRIM_HOST_DEVICE inline void for_each_in_tuple_impl(::rocprim::tuple& t, - Function f, - ::rocprim::index_sequence) +template +ROCPRIM_HOST_DEVICE inline void + for_each_in_tuple_impl(Tuple&& t, Function&& f, ::rocprim::index_sequence) { - auto swallow = {(f(::rocprim::get(t)), 0)...}; + int swallow[] + = {(std::forward(f)(::rocprim::get(std::forward(t))), 0)...}; (void)swallow; } -template -ROCPRIM_HOST_DEVICE inline void for_each_in_tuple(::rocprim::tuple& t, Function f) +template +ROCPRIM_HOST_DEVICE inline auto for_each_in_tuple(Tuple&& t, Function&& f) + -> void_t>> { - for_each_in_tuple_impl(t, f, ::rocprim::index_sequence_for()); + static constexpr size_t size = tuple_size>::value; + for_each_in_tuple_impl(std::forward(t), + std::forward(f), + ::rocprim::make_index_sequence()); } /// \brief Reinterprets the pointer as another type and increments it to match the alignment of diff --git 
a/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp b/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp index 8f5cdabdb..139ce7e8b 100644 --- a/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -39,8 +39,10 @@ namespace detail { template -struct default_adjacent_difference_config : default_adjacent_difference_config_base +struct default_adjacent_difference_config + : default_adjacent_difference_config_base::type {}; + // Based on value_type = double template struct default_adjacent_difference_config< diff --git a/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp b/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp index 0718b6e65..ab0df2f81 100644 --- a/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -40,7 +40,7 @@ namespace detail template struct default_adjacent_difference_inplace_config - : default_adjacent_difference_config_base + : default_adjacent_difference_config_base::type {}; // Based on value_type = double diff --git a/rocprim/include/rocprim/device/detail/config/device_histogram.hpp b/rocprim/include/rocprim/device/detail/config/device_histogram.hpp index e12d50069..cfd3dc72a 100644 --- a/rocprim/include/rocprim/device/detail/config/device_histogram.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_histogram.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -44,7 +44,7 @@ template struct default_histogram_config - : default_histogram_config_base + : default_histogram_config_base::type {}; // Based on value_type = double, channels = 1, active_channels = 1 diff --git a/rocprim/include/rocprim/device/detail/config/device_reduce.hpp b/rocprim/include/rocprim/device/detail/config/device_reduce.hpp index 54afc42f4..5def2822b 100644 --- a/rocprim/include/rocprim/device/detail/config/device_reduce.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_reduce.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -39,7 +39,7 @@ namespace detail { template -struct default_reduce_config : default_reduce_config_base +struct default_reduce_config : default_reduce_config_base::type {}; // Based on key_type = double diff --git a/rocprim/include/rocprim/device/detail/config/device_scan.hpp b/rocprim/include/rocprim/device/detail/config/device_scan.hpp index 7fe8e7259..c402b2178 100644 --- a/rocprim/include/rocprim/device/detail/config/device_scan.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_scan.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -39,7 +39,7 @@ namespace detail { template -struct default_scan_config : default_scan_config_base +struct default_scan_config : default_scan_config_base::type {}; // Based on value_type = double diff --git a/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp b/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp index f6f375f24..f5b10bce7 100644 --- a/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -39,7 +39,7 @@ namespace detail { template -struct default_scan_by_key_config : default_scan_by_key_config_base +struct default_scan_by_key_config : default_scan_by_key_config_base::type {}; // Based on key_type = double, value_type = int64_t diff --git a/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp b/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp new file mode 100644 index 000000000..51243cfb6 --- /dev/null +++ b/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp @@ -0,0 +1,4886 @@ +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEGMENTED_RADIX_SORT_HPP_ +#define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEGMENTED_RADIX_SORT_HPP_ + +#include "../../../type_traits.hpp" +#include "../device_config_helper.hpp" +#include + +/* DO NOT EDIT THIS FILE + * This file is automatically generated by `/scripts/autotune/create_optimization.py`. + * so most likely you want to edit rocprim/device/device_(algo)_config.hpp + */ + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template +struct default_segmented_radix_sort_config : default_segmented_radix_sort_config_base<7, 6>::type +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && 
(sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct 
default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config<6, + 4, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 256, 5, 8, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config<6, + 4, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 256, 5, 8, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct 
default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = 
int64_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 3, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 
256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + 
typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 5, 32, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 5, 32, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : 
segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 4, + 1, + typename std::conditional<1, + WarpSortConfig<8, 2, 256, 5, 16, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 4, + 1, + typename std::conditional<1, + WarpSortConfig<8, 2, 256, 5, 16, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config<6, + 4, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 256, 5, 8, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : 
segmented_radix_sort_config< + 6, + 4, + 256, + 4, + 1, + typename std::conditional<1, + WarpSortConfig<8, 2, 256, 5, 16, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : segmented_radix_sort_config<4, + 4, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 256, 5, 8, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && 
(sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 
4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config<7, + 2, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 256, 5, 8, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config<6, + 4, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 256, 5, 8, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 5, 32, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + 
static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = 
short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + 
DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename 
std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 4, + 1, + typename std::conditional<1, + WarpSortConfig<16, 2, 256, 5, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config<6, + 4, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 256, 5, 8, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 5, 32, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 4, + 1, + typename std::conditional<1, + 
WarpSortConfig<8, 2, 256, 5, 16, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 
16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && 
(sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && 
(sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + 
static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = 
short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + 
DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename 
std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename 
std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : 
segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) 
&& (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 5, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + 
key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct 
default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + 
+// Based on key_type = int, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 
4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 
6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 6, + 4, 
+ 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> 
+ : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 
2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config<6, + 4, + 128, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 128, 5, 8, 8, 128>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + 
value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 6, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + 
static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type 
= short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + 
DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename 
std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config<7, + 4, + 128, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 128, 5, 8, 8, 128>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config<6, + 4, + 128, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 128, 5, 8, 8, 128>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 16, + 1, + typename std::conditional<1, + 
WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 4, + 3, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, 
+ 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && 
(sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && 
(sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + 
static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type 
= short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + 
DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + 
typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + 
typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : 
segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && 
(sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + 
key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct 
default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// 
Based on key_type = int, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 
256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 
256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 
17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEGMENTED_RADIX_SORT_HPP_ \ No newline at end of file diff --git a/rocprim/include/rocprim/device/detail/device_config_helper.hpp b/rocprim/include/rocprim/device/detail/device_config_helper.hpp index 532b8d8cd..b6947170b 100644 --- a/rocprim/include/rocprim/device/detail/device_config_helper.hpp +++ b/rocprim/include/rocprim/device/detail/device_config_helper.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -255,7 +255,7 @@ namespace detail { template -struct default_reduce_config_base_helper +struct default_reduce_config_base { static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); @@ -265,10 +265,6 @@ struct default_reduce_config_base_helper ::rocprim::block_reduce_algorithm::using_warp_reduce>; }; -template -struct default_reduce_config_base : default_reduce_config_base_helper::type -{}; - struct scan_config_tag {}; @@ -337,7 +333,7 @@ struct scan_by_key_config_tag {}; template -struct default_scan_config_base_helper +struct default_scan_config_base { static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); @@ -349,10 +345,6 @@ struct default_scan_config_base_helper ::rocprim::block_scan_algorithm::using_warp_scan>; }; -template -struct default_scan_config_base : default_scan_config_base_helper::type -{}; - /// \brief Provides the kernel parameters for exclusive_scan_by_key and inclusive_scan_by_key based /// on autotuned configurations or user-provided configurations. 
struct scan_by_key_config_params @@ -415,7 +407,7 @@ namespace detail { template -struct default_scan_by_key_config_base_helper +struct default_scan_by_key_config_base { static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div( sizeof(Key) + sizeof(Value), 2 * sizeof(int)); @@ -428,10 +420,6 @@ struct default_scan_by_key_config_base_helper ::rocprim::block_scan_algorithm::using_warp_scan>; }; -template -struct default_scan_by_key_config_base : default_scan_by_key_config_base_helper::type -{}; - struct transform_config_tag {}; @@ -442,6 +430,203 @@ struct transform_config_params } // namespace detail +namespace detail +{ +struct segmented_radix_sort_config_tag +{}; + +struct warp_sort_config_params +{ + /// \brief Allow the partitioning of batches by size for processing via size-optimized kernels. + bool partitioning_allowed = false; + /// \brief The number of threads in the logical warp in the small segment processing kernel. + unsigned int logical_warp_size_small = 0; + /// \brief The number of items processed by a thread in the small segment processing kernel. + unsigned int items_per_thread_small = 0; + /// \brief The number of threads per block in the small segment processing kernel. + unsigned int block_size_small = 0; + /// \brief If the number of segments is at least \p partitioning_threshold, then the segments are partitioned into + /// small and large segment groups, and each group is handled by a different, specialized kernel. + unsigned int partitioning_threshold = 0; + /// \brief The number of threads in the logical warp in the medium segment processing kernel. + unsigned int logical_warp_size_medium = 0; + /// \brief The number of items processed by a thread in the medium segment processing kernel. + unsigned int items_per_thread_medium = 0; + /// \brief The number of threads per block in the medium segment processing kernel. 
+ unsigned int block_size_medium = 0; +}; + +struct segmented_radix_sort_config_params +{ + /// \brief Kernel start parameters. + kernel_config_params kernel_config{}; + /// \brief Number of bits in long iterations. + unsigned int long_radix_bits = 0; + /// \brief Number of bits in short iterations. + unsigned int short_radix_bits = 0; + /// \brief If set to \p true, warp sort can be used to sort the small segments, even if no partitioning happens. + bool enable_unpartitioned_warp_sort = true; + /// \brief Warp sort config params + warp_sort_config_params warp_sort_config{}; +}; + +} // namespace detail + +/// \brief Configuration of the warp sort part of the device segmented radix sort operation. +/// Short enough segments are processed on warp level. +/// +/// \tparam LogicalWarpSizeSmall - number of threads in the logical warp of the kernel +/// that processes small segments. +/// \tparam ItemsPerThreadSmall - number of items processed by a thread in the kernel that processes +/// small segments. +/// \tparam BlockSizeSmall - number of threads per block in the kernel which processes the small segments. +/// \tparam PartitioningThreshold - if the number of segments is at least this threshold, the +/// segments are partitioned to a small, a medium and a large segment collection. Both collections +/// are sorted by different kernels. Otherwise, all segments are sorted by a single kernel. +/// \tparam EnableUnpartitionedWarpSort - If set to \p true, warp sort can be used to sort +/// the small segments, even if the total number of segments is below \p PartitioningThreshold. +/// \tparam LogicalWarpSizeMedium - number of threads in the logical warp of the kernel +/// that processes medium segments. +/// \tparam ItemsPerThreadMedium - number of items processed by a thread in the kernel that processes +/// medium segments. +/// \tparam BlockSizeMedium - number of threads per block in the kernel which processes the medium segments. 
+template +struct WarpSortConfig +{ + static_assert(LogicalWarpSizeSmall * ItemsPerThreadSmall + <= LogicalWarpSizeMedium * ItemsPerThreadMedium, + "The number of items processed by a small warp cannot be larger than the number " + "of items processed by a medium warp"); + + /// \brief Allow the partitioning of batches by size for processing via size-optimized kernels. + static constexpr bool partitioning_allowed = true; + /// \brief The number of threads in the logical warp in the small segment processing kernel. + static constexpr unsigned int logical_warp_size_small = LogicalWarpSizeSmall; + /// \brief The number of items processed by a thread in the small segment processing kernel. + static constexpr unsigned int items_per_thread_small = ItemsPerThreadSmall; + /// \brief The number of threads per block in the small segment processing kernel. + static constexpr unsigned int block_size_small = BlockSizeSmall; + /// \brief If the number of segments is at least \p partitioning_threshold, then the segments are partitioned into + /// small and large segment groups, and each group is handled by a different, specialized kernel. + static constexpr unsigned int partitioning_threshold = PartitioningThreshold; + /// \brief The number of threads in the logical warp in the medium segment processing kernel. + static constexpr unsigned int logical_warp_size_medium = LogicalWarpSizeMedium; + /// \brief The number of items processed by a thread in the medium segment processing kernel. + static constexpr unsigned int items_per_thread_medium = ItemsPerThreadMedium; + /// \brief The number of threads per block in the medium segment processing kernel. + static constexpr unsigned int block_size_medium = BlockSizeMedium; +}; + +/// \brief Indicates if the warp level sorting is disabled in the +/// device segmented radix sort configuration. +struct DisabledWarpSortConfig +{ + /// \brief Allow the partitioning of batches by size for processing via size-optimized kernels. 
+ static constexpr bool partitioning_allowed = false; + /// \brief The number of threads in the logical warp in the small segment processing kernel. + static constexpr unsigned int logical_warp_size_small = 1; + /// \brief The number of items processed by a thread in the small segment processing kernel. + static constexpr unsigned int items_per_thread_small = 1; + /// \brief The number of threads per block in the small segment processing kernel. + static constexpr unsigned int block_size_small = 1; + /// \brief If the number of segments is at least \p partitioning_threshold, then the segments are partitioned into + /// small and large segment groups, and each group is handled by a different, specialized kernel. + static constexpr unsigned int partitioning_threshold = 0; + /// \brief The number of threads in the logical warp in the medium segment processing kernel. + static constexpr unsigned int logical_warp_size_medium = 1; + /// \brief The number of items processed by a thread in the medium segment processing kernel. + static constexpr unsigned int items_per_thread_medium = 1; + /// \brief The number of threads per block in the medium segment processing kernel. + static constexpr unsigned int block_size_medium = 1; +}; + +/// \brief Configuration for the device-level segmented radix sort operation. +/// \tparam LongRadixBits . +/// \tparam ShortRadixBits . +/// \tparam BlockSize Number of threads in a block. +/// \tparam ItemsPerThread Number of items processed by each thread. +/// \tparam SizeLimit Limit on the number of items for a single kernel launch. +template +struct segmented_radix_sort_config : public detail::segmented_radix_sort_config_params +{ + /// \brief Identifies the algorithm associated to the config. + /// \brief Identifies the algorithm associated to the config. + using tag = detail::segmented_radix_sort_config_tag; +#ifndef DOXYGEN_SHOULD_SKIP_THIS + + /// \brief Number of bits in long iterations. 
+ static constexpr unsigned int long_radix_bits = LongRadixBits; + + /// \brief Number of bits in short iterations. + static constexpr unsigned int short_radix_bits = ShortRadixBits; + + /// \brief Number of threads in a block. + static constexpr unsigned int block_size = BlockSize; + + /// \brief Number of items processed by each thread. + static constexpr unsigned int items_per_thread = ItemsPerThread; + + /// \brief If set to \p true, warp sort can be used to sort the small segments, even if no partitioning happens. + static constexpr bool enable_unpartitioned_warp_sort = EnableUnpartitionedWarpSort; + + /// \brief Limit on the number of items for a single kernel launch. + static constexpr unsigned int size_limit = SizeLimit; + + using warp_sort_config = WarpSortConfig; + + constexpr segmented_radix_sort_config() + : detail::segmented_radix_sort_config_params{ + {BlockSize, ItemsPerThread, SizeLimit}, + LongRadixBits, + ShortRadixBits, + EnableUnpartitionedWarpSort, + {warp_sort_config::partitioning_allowed, + warp_sort_config::logical_warp_size_small, + warp_sort_config::items_per_thread_small, + warp_sort_config::block_size_small, + warp_sort_config::partitioning_threshold, + warp_sort_config::logical_warp_size_medium, + warp_sort_config::items_per_thread_medium, + warp_sort_config::block_size_medium} + } + {} +#endif +}; + +namespace detail +{ +/// \brief Default segmented_radix_sort kernel configurations, such that the maximum shared memory is not exceeded. +/// +/// \tparam LongRadixBits - Long bits used during the sorting. +/// \tparam ShortRadixBits - Short bits used during the sorting. +/// \tparam ItemsPerThread - Items per thread when type Key has size 1. 
+template +struct default_segmented_radix_sort_config_base +{ + static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div( + sizeof(unsigned int) + sizeof(unsigned int), sizeof(int)); + using type = segmented_radix_sort_config>; +}; + +} // namespace detail + /// \brief Configuration for the device-level transform operation. /// \tparam BlockSize Number of threads in a block. /// \tparam ItemsPerThread Number of items processed by each thread. @@ -476,7 +661,7 @@ namespace detail { template -struct default_transform_config_base_helper +struct default_transform_config_base { static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); @@ -484,10 +669,6 @@ struct default_transform_config_base_helper using type = transform_config<256, ::rocprim::max(1u, 16u / item_scale)>; }; -template -struct default_transform_config_base : default_transform_config_base_helper::type -{}; - struct binary_search_config_tag : public transform_config_tag {}; struct upper_bound_config_tag : public transform_config_tag @@ -597,7 +778,7 @@ namespace detail { template -struct default_histogram_config_base_helper +struct default_histogram_config_base { static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Sample), sizeof(int)); @@ -606,11 +787,6 @@ struct default_histogram_config_base_helper = histogram_config>; }; -template -struct default_histogram_config_base - : default_histogram_config_base_helper::type -{}; - struct adjacent_difference_config_tag {}; @@ -657,7 +833,7 @@ namespace detail { template -struct default_adjacent_difference_config_base_helper +struct default_adjacent_difference_config_base { static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); @@ -669,11 +845,6 @@ struct default_adjacent_difference_config_base_helper ::rocprim::block_store_method::block_store_transpose>; }; -template -struct default_adjacent_difference_config_base - : 
default_adjacent_difference_config_base_helper::type -{}; - } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/device/detail/device_partition.hpp b/rocprim/include/rocprim/device/detail/device_partition.hpp index 7b49b2200..8fc63acc6 100644 --- a/rocprim/include/rocprim/device/detail/device_partition.hpp +++ b/rocprim/include/rocprim/device/detail/device_partition.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -688,6 +688,153 @@ ROCPRIM_DEVICE void load_selected_count(const size_t* const prev_selected_count, } } +template +class partition_values_helper +{ + +private: + ValueType values[ItemsPerThread]; + +public: + ROCPRIM_DEVICE void load(ValueIterator values_input, + unsigned int valid, + unsigned int is_global_last_block, + typename BlockLoadValueType::storage_type& load_values) + { + // Load values and sync threads + if(is_global_last_block) + { + BlockLoadValueType().load(values_input, values, valid, load_values); + } + else + { + BlockLoadValueType().load(values_input, values, load_values); + } + ::rocprim::syncthreads(); + } + + template + ROCPRIM_DEVICE void store(bool (&is_selected)[ItemsPerThread], + OffsetType (&output_indices)[ItemsPerThread], + OutputType values_output, + const size_t total_size, + const OffsetType selected_prefix, + const OffsetType selected_in_block, + ScatterStorageType& storage, + const unsigned int flat_block_id, + const unsigned int flat_block_thread_id, + const bool is_global_last_block, + const unsigned int valid_in_global_last_block, + size_t (&prev_selected_count_values)[1], + size_t prev_processed) + { + // Sync threads and store values + ::rocprim::syncthreads(); + partition_scatter(values, + 
is_selected, + output_indices, + values_output, + total_size, + selected_prefix, + selected_in_block, + storage, + flat_block_id, + flat_block_thread_id, + is_global_last_block, + valid_in_global_last_block, + prev_selected_count_values, + prev_processed); + } + + template + ROCPRIM_DEVICE void store(bool (&is_selected)[2][ItemsPerThread], + OffsetType (&output_indices)[ItemsPerThread], + OutputType values_output, + const size_t total_size, + const OffsetType selected_prefix, + const OffsetType selected_in_block, + ScatterStorageType& storage, + const unsigned int flat_block_id, + const unsigned int flat_block_thread_id, + const bool is_global_last_block, + const unsigned int valid_in_global_last_block, + size_t (&prev_selected_count_values)[2], + size_t prev_processed) + { + // Sync threads and store values + ::rocprim::syncthreads(); + partition_scatter(values, + is_selected, + output_indices, + values_output, + total_size, + selected_prefix, + selected_in_block, + storage, + flat_block_id, + flat_block_thread_id, + is_global_last_block, + valid_in_global_last_block, + prev_selected_count_values, + prev_processed); + } +}; + +template +class partition_values_helper +{ +public: + ROCPRIM_DEVICE void + load(ValueIterator, unsigned int, unsigned int, typename BlockLoadValueType::storage_type&) + {} + + template + ROCPRIM_DEVICE void store(bool[ItemsPerThread], + OffsetType[ItemsPerThread], + OutputType, + const size_t, + const OffsetType, + const OffsetType, + ScatterStorageType&, + const unsigned int, + const unsigned int, + const bool, + const unsigned int, + size_t[ItemsPerThread], + size_t) + {} + + template + ROCPRIM_DEVICE void store(bool[2][ItemsPerThread], + OffsetType[ItemsPerThread], + OutputType, + const size_t, + const OffsetType, + const OffsetType, + ScatterStorageType&, + const unsigned int, + const unsigned int, + const bool, + const unsigned int, + size_t[ItemsPerThread], + size_t) + {} +}; + template + values_helper; + 
values_helper.load(values_input + block_offset, + valid_in_global_last_block, + is_global_last_block, + storage.load_values); + // Load selection flags into is_selected, generate them using // input value and selection predicate, or generate them using // block_discontinuity primitive @@ -885,45 +1045,19 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void prev_selected_count_values, prev_processed); - static constexpr bool with_values = !std::is_same::value; - - if ROCPRIM_IF_CONSTEXPR (with_values) { - value_type values[items_per_thread]; - - ::rocprim::syncthreads(); // sync threads to reuse shared memory - if(is_global_last_block) - { - block_load_value_type().load(values_input + block_offset, - values, - valid_in_global_last_block, - storage.load_values); - } - else - { - block_load_value_type() - .load( - values_input + block_offset, - values, - storage.load_values - ); - } - ::rocprim::syncthreads(); // sync threads to reuse shared memory - - partition_scatter(values, - is_selected, - output_indices, - values_output, - total_size, - selected_prefix, - selected_in_block, - storage.exchange_values, - flat_block_id, - flat_block_thread_id, - is_global_last_block, - valid_in_global_last_block, - prev_selected_count_values, - prev_processed); - } + values_helper.store(is_selected, + output_indices, + values_output, + total_size, + selected_prefix, + selected_in_block, + storage.exchange_values, + flat_block_id, + flat_block_thread_id, + is_global_last_block, + valid_in_global_last_block, + prev_selected_count_values, + prev_processed); // Last block in grid stores number of selected values const bool is_last_block = flat_block_id == (number_of_blocks - 1); diff --git a/rocprim/include/rocprim/device/detail/device_reduce_by_key.hpp b/rocprim/include/rocprim/device/detail/device_reduce_by_key.hpp index 16bc5b866..4c0e6f21a 100644 --- a/rocprim/include/rocprim/device/detail/device_reduce_by_key.hpp +++ b/rocprim/include/rocprim/device/detail/device_reduce_by_key.hpp @@ -1,4 
+1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -28,7 +28,6 @@ #include "../../block/block_load.hpp" #include "../../block/block_scan.hpp" #include "../../block/block_store.hpp" -#include "../../detail/match_result_type.hpp" #include "../../detail/various.hpp" #include "../../intrinsics/thread.hpp" #include "../../thread/thread_operators.hpp" @@ -51,7 +50,7 @@ using value_type_t = typename std::iterator_traits::value_type; template using accumulator_type_t = - typename detail::match_result_type, BinaryOp>::type; + typename invoke_result_binary_op, BinaryOp>::type; template using wrapped_type_t = rocprim::tuple; diff --git a/rocprim/include/rocprim/device/detail/device_scan.hpp b/rocprim/include/rocprim/device/detail/device_scan.hpp index b11a1cfe2..90b31f7bf 100644 --- a/rocprim/include/rocprim/device/detail/device_scan.hpp +++ b/rocprim/include/rocprim/device/detail/device_scan.hpp @@ -93,23 +93,22 @@ template ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void lookback_scan_kernel_impl(InputIterator input, OutputIterator output, const size_t size, - ResultType initial_value, + AccType initial_value, BinaryFunction scan_op, LookbackScanState scan_state, const unsigned int number_of_blocks, - ResultType* previous_last_element = nullptr, - ResultType* new_last_element = nullptr, + AccType* previous_last_element = nullptr, + AccType* new_last_element = nullptr, bool override_first_value = false, bool save_last_value = false) { - using result_type = ResultType; - static_assert(std::is_same::value, + static_assert(std::is_same::value, "value_type of LookbackScanState must be result_type"); static constexpr scan_config_params params = device_params(); @@ -117,15 +116,14 @@ ROCPRIM_DEVICE 
ROCPRIM_FORCE_INLINE void constexpr auto items_per_thread = params.kernel_config.items_per_thread; constexpr unsigned int items_per_block = block_size * items_per_thread; - using block_load_type = ::rocprim:: - block_load; - using block_store_type = ::rocprim:: - block_store; - using block_scan_type - = ::rocprim::block_scan; + using block_load_type + = ::rocprim::block_load; + using block_store_type + = ::rocprim::block_store; + using block_scan_type = ::rocprim::block_scan; using lookback_scan_prefix_op_type - = lookback_scan_prefix_op; + = lookback_scan_prefix_op; ROCPRIM_SHARED_MEMORY union { @@ -140,7 +138,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void const auto valid_in_last_block = size - items_per_block * (number_of_blocks - 1); // For input values - result_type values[items_per_thread]; + AccType values[items_per_thread]; // load input values into values if(flat_block_id == (number_of_blocks - 1)) // last block @@ -165,12 +163,12 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void { if(Exclusive) initial_value - = scan_op(previous_last_element[0], static_cast(*(input - 1))); + = scan_op(previous_last_element[0], static_cast(*(input - 1))); else if(flat_block_thread_id == 0) values[0] = scan_op(previous_last_element[0], values[0]); } - result_type reduction; + AccType reduction; lookback_block_scan(values, // input/output initial_value, reduction, diff --git a/rocprim/include/rocprim/device/detail/device_scan_by_key.hpp b/rocprim/include/rocprim/device/detail/device_scan_by_key.hpp index fe481b0b2..1efb9c6fd 100644 --- a/rocprim/include/rocprim/device/detail/device_scan_by_key.hpp +++ b/rocprim/include/rocprim/device/detail/device_scan_by_key.hpp @@ -32,6 +32,7 @@ #include "../../detail/binary_op_wrappers.hpp" #include "../../intrinsics/thread.hpp" #include "../../types/tuple.hpp" +#include "rocprim/device/detail/device_config_helper.hpp" #include diff --git a/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp 
b/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp index b1a8cf33b..52fd28086 100644 --- a/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp +++ b/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -486,21 +486,14 @@ struct DisabledWarpSortHelperConfig static constexpr unsigned int block_size = 1; }; -template -using select_warp_sort_helper_config_small_t - = std::conditional_t::value, - DisabledWarpSortHelperConfig, - WarpSortHelperConfig>; - -template -using select_warp_sort_helper_config_medium_t - = std::conditional_t::value, - DisabledWarpSortHelperConfig, - WarpSortHelperConfig>; +template +using select_warp_sort_helper_config_t + = std::conditional_t, + DisabledWarpSortHelperConfig>; template< class Config, @@ -705,12 +698,14 @@ void segmented_sort(KeysInputIterator keys_input, unsigned int begin_bit, unsigned int end_bit) { - constexpr unsigned int long_radix_bits = Config::long_radix_bits; - constexpr unsigned int short_radix_bits = Config::short_radix_bits; - constexpr unsigned int block_size = Config::sort::block_size; - constexpr unsigned int items_per_thread = Config::sort::items_per_thread; - constexpr unsigned int items_per_block = block_size * items_per_thread; - constexpr bool warp_sort_enabled = Config::warp_sort_config::enable_unpartitioned_warp_sort; + static constexpr segmented_radix_sort_config_params params = device_params(); + + static constexpr unsigned int long_radix_bits = params.long_radix_bits; + static constexpr unsigned int short_radix_bits = params.short_radix_bits; + static constexpr unsigned int block_size = params.kernel_config.block_size; 
+ static constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread; + static constexpr unsigned int items_per_block = block_size * items_per_thread; + static constexpr bool warp_sort_enabled = params.enable_unpartitioned_warp_sort; using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; @@ -731,7 +726,10 @@ void segmented_sort(KeysInputIterator keys_input, short_radix_bits, Descending >; using warp_sort_helper_type = segmented_warp_sort_helper< - select_warp_sort_helper_config_small_t, + select_warp_sort_helper_config_t, key_type, value_type, Descending>; @@ -798,7 +796,7 @@ void segmented_sort(KeysInputIterator keys_input, storage.single_block_helper ); } - else if(::rocprim::flat_block_thread_id() < Config::warp_sort_config::logical_warp_size_small) + else if(::rocprim::flat_block_thread_id() < params.warp_sort_config.logical_warp_size_small) { // Single warp segment warp_sort_helper_type().sort( @@ -837,11 +835,13 @@ void segmented_sort_large(KeysInputIterator keys_input, unsigned int begin_bit, unsigned int end_bit) { - constexpr unsigned int long_radix_bits = Config::long_radix_bits; - constexpr unsigned int short_radix_bits = Config::short_radix_bits; - constexpr unsigned int block_size = Config::sort::block_size; - constexpr unsigned int items_per_thread = Config::sort::items_per_thread; - constexpr unsigned int items_per_block = block_size * items_per_thread; + static constexpr segmented_radix_sort_config_params params = device_params(); + + static constexpr unsigned int long_radix_bits = params.long_radix_bits; + static constexpr unsigned int short_radix_bits = params.short_radix_bits; + static constexpr unsigned int block_size = params.kernel_config.block_size; + static constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread; + static constexpr unsigned int items_per_block = block_size * items_per_thread; using key_type = typename 
std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; @@ -946,17 +946,101 @@ void segmented_sort_small(KeysInputIterator keys_input, unsigned int begin_bit, unsigned int end_bit) { - static constexpr unsigned int block_size = Config::block_size; - static constexpr unsigned int logical_warp_size = Config::logical_warp_size; - static_assert(block_size % logical_warp_size == 0, "logical_warp_size must be a divisor of block_size"); + static constexpr segmented_radix_sort_config_params params = device_params(); + + static constexpr unsigned int block_size = params.warp_sort_config.block_size_small; + static constexpr unsigned int logical_warp_size + = params.warp_sort_config.logical_warp_size_small; + static_assert(block_size % logical_warp_size == 0, + "logical_warp_size must be a divisor of block_size"); + static constexpr unsigned int warps_per_block = block_size / logical_warp_size; + + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + + using warp_sort_helper_type = segmented_warp_sort_helper< + select_warp_sort_helper_config_t, + key_type, + value_type, + Descending>; + + ROCPRIM_SHARED_MEMORY typename warp_sort_helper_type::storage_type storage; + + const unsigned int block_id = ::rocprim::detail::block_id<0>(); + const unsigned int logical_warp_id = ::rocprim::detail::logical_warp_id(); + const unsigned int segment_index = block_id * warps_per_block + logical_warp_id; + if(segment_index >= num_segments) + { + return; + } + + const unsigned int segment_id = segment_indices[segment_index]; + const unsigned int begin_offset = begin_offsets[segment_id]; + const unsigned int end_offset = end_offsets[segment_id]; + if(end_offset <= begin_offset) + { + return; + } + warp_sort_helper_type().sort(keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + to_output, + begin_offset, + end_offset, + begin_bit, + end_bit, + 
storage); +} + +template +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void segmented_sort_medium( + KeysInputIterator keys_input, + typename std::iterator_traits::value_type* keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits::value_type* values_tmp, + ValuesOutputIterator values_output, + bool to_output, + unsigned int num_segments, + SegmentIndexIterator segment_indices, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int begin_bit, + unsigned int end_bit) +{ + static constexpr segmented_radix_sort_config_params params = device_params(); + + static constexpr unsigned int block_size = params.warp_sort_config.block_size_medium; + static constexpr unsigned int logical_warp_size + = params.warp_sort_config.logical_warp_size_medium; + static_assert(block_size % logical_warp_size == 0, + "logical_warp_size must be a divisor of block_size"); static constexpr unsigned int warps_per_block = block_size / logical_warp_size; using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; using warp_sort_helper_type = segmented_warp_sort_helper< - Config, key_type, value_type, Descending - >; + select_warp_sort_helper_config_t, + key_type, + value_type, + Descending>; ROCPRIM_SHARED_MEMORY typename warp_sort_helper_type::storage_type storage; diff --git a/rocprim/include/rocprim/device/detail/device_transform.hpp b/rocprim/include/rocprim/device/detail/device_transform.hpp index e84ee0731..1f966e264 100644 --- a/rocprim/include/rocprim/device/detail/device_transform.hpp +++ b/rocprim/include/rocprim/device/detail/device_transform.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -26,7 +26,6 @@ #include "../../config.hpp" #include "../../detail/various.hpp" -#include "../../detail/match_result_type.hpp" #include "../../intrinsics.hpp" #include "../../functional.hpp" @@ -45,7 +44,7 @@ namespace detail template struct unpack_binary_op { - using result_type = typename ::rocprim::detail::invoke_result::type; + using result_type = typename ::rocprim::invoke_result::type; ROCPRIM_HOST_DEVICE inline unpack_binary_op() = default; diff --git a/rocprim/include/rocprim/device/device_adjacent_difference.hpp b/rocprim/include/rocprim/device/device_adjacent_difference.hpp index 917bf4328..13476ce43 100644 --- a/rocprim/include/rocprim/device/device_adjacent_difference.hpp +++ b/rocprim/include/rocprim/device/device_adjacent_difference.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -154,9 +154,10 @@ hipError_t adjacent_difference_impl(void* const temporary_storage, // next block, otherwise the first item is needed for the previous block const auto offset = items_per_block - (Right ? 
0 : 1); - const auto block_starts_iter = make_transform_iterator( - rocprim::make_counting_iterator(std::size_t{0}), - [=, base = input + offset](std::size_t i) { return base[i * items_per_block]; }); + const auto block_starts_iter + = make_transform_iterator(rocprim::make_counting_iterator(std::size_t{0}), + [=, base = input + offset](std::size_t i) -> value_type + { return base[i * items_per_block]; }); const hipError_t error = ::rocprim::transform(block_starts_iter, previous_values, @@ -244,26 +245,26 @@ hipError_t adjacent_difference_impl(void* const temporary_storage, /// } /// \endcode /// -/// \tparam Config - [optional] configuration of the primitive. It has to be +/// \tparam Config [optional] configuration of the primitive. It has to be /// `adjacent_difference_config` or a class derived from it. -/// \tparam InputIt - [inferred] random-access iterator type of the input range. Must meet the +/// \tparam InputIt [inferred] random-access iterator type of the input range. Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam OutputIt - [inferred] random-access iterator type of the output range. Must meet the +/// \tparam OutputIt [inferred] random-access iterator type of the output range. Must meet the /// requirements of a C++ OutputIterator concept. It can be a simple pointer type. -/// \tparam BinaryFunction - [inferred] binary operation function object that will be applied to +/// \tparam BinaryFunction [inferred] binary operation function object that will be applied to /// consecutive items. The signature of the function should be equivalent to the following: /// `U f(const T1& a, const T2& b)`. The signature does not need to have /// `const &`, but function object must not modify the object passed to it -/// \param temporary_storage - pointer to a device-accessible temporary storage. When +/// \param temporary_storage pointer to a device-accessible temporary storage. 
When /// a null pointer is passed, the required allocation size (in bytes) is written to /// `storage_size` and function returns without performing the scan operation -/// \param storage_size - reference to a size (in bytes) of `temporary_storage` -/// \param input - iterator to the input range -/// \param output - iterator to the output range, must have any overlap with input -/// \param size - number of items in the input -/// \param op - [optional] the binary operation to apply -/// \param stream - [optional] HIP stream object. Default is `0` (the default stream) -/// \param debug_synchronous - [optional] If true, synchronization after every kernel +/// \param storage_size reference to a size (in bytes) of `temporary_storage` +/// \param input iterator to the input range +/// \param output iterator to the output range, must not have any overlap with input. +/// \param size number of items in the input +/// \param op [optional] the binary operation to apply +/// \param stream [optional] HIP stream object. Default is `0` (the default stream) +/// \param debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors and extra debugging info is printed to the /// standard output. Default value is `false` /// @@ -340,23 +341,23 @@ hipError_t adjacent_difference(void* const temporary_storage, /// } /// \endcode /// -/// \tparam Config - [optional] configuration of the primitive. It has to be +/// \tparam Config [optional] configuration of the primitive. It has to be /// `adjacent_difference_config` or a class derived from it. -/// \tparam InputIt - [inferred] random-access iterator type of the value range. Must meet the +/// \tparam InputIt [inferred] random-access iterator type of the value range. Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. 
-/// \tparam BinaryFunction - [inferred] binary operation function object that will be applied to +/// \tparam BinaryFunction [inferred] binary operation function object that will be applied to /// consecutive items. The signature of the function should be equivalent to the following: /// `U f(const T1& a, const T2& b)`. The signature does not need to have /// `const &`, but function object must not modify the object passed to it -/// \param temporary_storage - pointer to a device-accessible temporary storage. When +/// \param temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to /// `storage_size` and function returns without performing the scan operation -/// \param storage_size - reference to a size (in bytes) of `temporary_storage` -/// \param values - iterator to the range values, will be overwritten with the results -/// \param size - number of items in the input -/// \param op - [optional] the binary operation to apply -/// \param stream - [optional] HIP stream object. Default is `0` (the default stream) -/// \param debug_synchronous - [optional] If true, synchronization after every kernel +/// \param storage_size reference to a size (in bytes) of `temporary_storage` +/// \param values iterator to the range values, will be overwritten with the results +/// \param size number of items in the input +/// \param op [optional] the binary operation to apply +/// \param stream [optional] HIP stream object. Default is `0` (the default stream) +/// \param debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors and extra debugging info is printed to the /// standard output. 
Default value is `false` /// @@ -379,6 +380,65 @@ hipError_t adjacent_difference_inplace(void* const temporary_storage, temporary_storage, storage_size, values, values, size, op, stream, debug_synchronous); } +/// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements +/// in device accessible memory. Writes the output to the position of the left item. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be +/// `adjacent_difference_config` or a class derived from it. +/// \tparam InputIt [inferred] random-access iterator type of the value range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam OutputIt [inferred] random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam BinaryFunction [inferred] binary operation function object that will be applied to +/// consecutive items. The signature of the function should be equivalent to the following: +/// `U f(const T1& a, const T2& b)`. The signature does not need to have +/// `const &`, but function object must not modify the object passed to it +/// +/// \param temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// `storage_size` and function returns without performing the scan operation +/// \param storage_size reference to a size (in bytes) of `temporary_storage` +/// \param input iterator to the range values +/// \param output iterator to the output range. Allowed to point to the same elements as `input`. +/// Only complete overlap or no overlap at all is allowed between `input` and `output`. In other words +/// writing to `output[i]` is only allowed to overwrite `input[i]`, any other element must not be changed. 
+/// \param size number of items in the input +/// \param op [optional] the binary operation to apply +/// \param stream [optional] HIP stream object. Default is `0` (the default stream) +/// \param debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors and extra debugging info is printed to the +/// standard output. Default value is `false` +/// +/// \return `hipSuccess` (0) on success, otherwise the HIP runtime error of +/// type `hipError_t` +/// +/// \note This function has to perform an extra copy due to (potentially) writing its values in-place. If it is known that `input` and `output` +/// don't overlap then adjacent_difference should be preferred as it avoids this extra copy. +template> +hipError_t adjacent_difference_inplace(void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt output, + const std::size_t size, + const BinaryFunction op = BinaryFunction{}, + const hipStream_t stream = 0, + const bool debug_synchronous = false) +{ + static constexpr bool in_place = true; + static constexpr bool right = false; + return detail::adjacent_difference_impl(temporary_storage, + storage_size, + input, + output, + size, + op, + stream, + debug_synchronous); +} + /// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements /// in device accessible memory. Writes the output to the position of the right item. /// @@ -393,26 +453,26 @@ hipError_t adjacent_difference_inplace(void* const temporary_storage, /// } /// \endcode /// -/// \tparam Config - [optional] configuration of the primitive. It has to be +/// \tparam Config [optional] configuration of the primitive. It has to be /// `adjacent_difference_config` or a class derived from it. -/// \tparam InputIt - [inferred] random-access iterator type of the input range. Must meet the +/// \tparam InputIt [inferred] random-access iterator type of the input range. 
Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam OutputIt - [inferred] random-access iterator type of the output range. Must meet the +/// \tparam OutputIt [inferred] random-access iterator type of the output range. Must meet the /// requirements of a C++ OutputIterator concept. It can be a simple pointer type. -/// \tparam BinaryFunction - [inferred] binary operation function object that will be applied to +/// \tparam BinaryFunction [inferred] binary operation function object that will be applied to /// consecutive items. The signature of the function should be equivalent to the following: /// `U f(const T1& a, const T2& b)`. The signature does not need to have /// `const &`, but function object must not modify the object passed to it -/// \param temporary_storage - pointer to a device-accessible temporary storage. When +/// \param temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to /// `storage_size` and function returns without performing the scan operation -/// \param storage_size - reference to a size (in bytes) of `temporary_storage` -/// \param input - iterator to the input range -/// \param output - iterator to the output range, must have any overlap with input -/// \param size - number of items in the input -/// \param op - [optional] the binary operation to apply -/// \param stream - [optional] HIP stream object. Default is `0` (the default stream) -/// \param debug_synchronous - [optional] If true, synchronization after every kernel +/// \param storage_size reference to a size (in bytes) of `temporary_storage` +/// \param input iterator to the input range +/// \param output iterator to the output range, must not have any overlap with input. +/// \param size number of items in the input +/// \param op [optional] the binary operation to apply +/// \param stream [optional] HIP stream object. 
Default is `0` (the default stream) +/// \param debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors and extra debugging info is printed to the /// standard output. Default value is `false` /// @@ -489,23 +549,23 @@ hipError_t adjacent_difference_right(void* const temporary_storage, /// } /// \endcode /// -/// \tparam Config - [optional] configuration of the primitive. It has to be +/// \tparam Config [optional] configuration of the primitive. It has to be /// `adjacent_difference_config` or a class derived from it. -/// \tparam InputIt - [inferred] random-access iterator type of the value range. Must meet the +/// \tparam InputIt [inferred] random-access iterator type of the value range. Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam BinaryFunction - [inferred] binary operation function object that will be applied to +/// \tparam BinaryFunction [inferred] binary operation function object that will be applied to /// consecutive items. The signature of the function should be equivalent to the following: /// `U f(const T1& a, const T2& b)`. The signature does not need to have /// `const &`, but function object must not modify the object passed to it -/// \param temporary_storage - pointer to a device-accessible temporary storage. When +/// \param temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to /// `storage_size` and function returns without performing the scan operation -/// \param storage_size - reference to a size (in bytes) of `temporary_storage` -/// \param values - iterator to the range values, will be overwritten with the results -/// \param size - number of items in the input -/// \param op - [optional] the binary operation to apply -/// \param stream - [optional] HIP stream object. 
Default is `0` (the default stream) -/// \param debug_synchronous - [optional] If true, synchronization after every kernel +/// \param storage_size reference to a size (in bytes) of `temporary_storage` +/// \param values iterator to the range values, will be overwritten with the results +/// \param size number of items in the input +/// \param op [optional] the binary operation to apply +/// \param stream [optional] HIP stream object. Default is `0` (the default stream) +/// \param debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors and extra debugging info is printed to the /// standard output. Default value is `false` /// @@ -528,6 +588,64 @@ hipError_t adjacent_difference_right_inplace(void* const temporary_stor temporary_storage, storage_size, values, values, size, op, stream, debug_synchronous); } +/// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements +/// in device accessible memory. Writes the output to the position of the right item. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be +/// `adjacent_difference_config` or a class derived from it. +/// \tparam InputIt [inferred] random-access iterator type of the value range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam OutputIt [inferred] random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam BinaryFunction [inferred] binary operation function object that will be applied to +/// consecutive items. The signature of the function should be equivalent to the following: +/// `U f(const T1& a, const T2& b)`. The signature does not need to have +/// `const &`, but function object must not modify the object passed to it +/// +/// \param temporary_storage pointer to a device-accessible temporary storage.
When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// `storage_size` and function returns without performing the scan operation +/// \param storage_size reference to a size (in bytes) of `temporary_storage` +/// \param input iterator to the range values, will be overwritten with the results only if `output` points to the same elements +/// \param output iterator to the output range. Allowed to point to the same elements as `input`. +/// Only complete overlap or no overlap at all is allowed between `input` and `output`. In other words +/// writing to `output[i]` is only allowed to overwrite `input[i]`, any other element must not be changed. +/// \param size number of items in the input +/// \param op [optional] the binary operation to apply +/// \param stream [optional] HIP stream object. Default is `0` (the default stream) +/// \param debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors and extra debugging info is printed to the +/// standard output. Default value is `false` +/// +/// \return `hipSuccess` (0) on success, otherwise the HIP runtime error of +/// type `hipError_t` +/// \note This function has to perform an extra copy due to (potentially) writing its values in-place. If it is known that `input` and `output` +/// don't overlap then adjacent_difference_right should be preferred as it avoids this extra copy.
+template> +hipError_t adjacent_difference_right_inplace(void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt output, + const std::size_t size, + const BinaryFunction op = BinaryFunction{}, + const hipStream_t stream = 0, + const bool debug_synchronous = false) +{ + static constexpr bool in_place = true; + static constexpr bool right = true; + return detail::adjacent_difference_impl(temporary_storage, + storage_size, + input, + output, + size, + op, + stream, + debug_synchronous); +} + /// @} // end of group devicemodule diff --git a/rocprim/include/rocprim/device/device_adjacent_difference_config.hpp b/rocprim/include/rocprim/device/device_adjacent_difference_config.hpp index 0299484f1..e68391dd0 100644 --- a/rocprim/include/rocprim/device/device_adjacent_difference_config.hpp +++ b/rocprim/include/rocprim/device/device_adjacent_difference_config.hpp @@ -83,6 +83,11 @@ struct wrapped_adjacent_difference_config }; #ifndef DOXYGEN_SHOULD_SKIP_THIS +template +template +constexpr adjacent_difference_config_params + wrapped_adjacent_difference_config:: + architecture_config::params; template template constexpr adjacent_difference_config_params diff --git a/rocprim/include/rocprim/device/device_reduce.hpp b/rocprim/include/rocprim/device/device_reduce.hpp index 649e1c583..f72339b15 100644 --- a/rocprim/include/rocprim/device/device_reduce.hpp +++ b/rocprim/include/rocprim/device/device_reduce.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -29,7 +29,6 @@ #include "config_types.hpp" #include "../config.hpp" -#include "../detail/match_result_type.hpp" #include "../detail/temp_storage.hpp" #include "../detail/various.hpp" @@ -112,9 +111,8 @@ hipError_t reduce_impl(void * temporary_storage, bool debug_synchronous) { using input_type = typename std::iterator_traits::value_type; - using result_type = typename ::rocprim::detail::match_result_type< - input_type, BinaryFunction - >::type; + using result_type = + typename ::rocprim::invoke_result_binary_op::type; using config = wrapped_reduce_config; diff --git a/rocprim/include/rocprim/device/device_reduce_by_key.hpp b/rocprim/include/rocprim/device/device_reduce_by_key.hpp index 315f75668..db63d50d5 100644 --- a/rocprim/include/rocprim/device/device_reduce_by_key.hpp +++ b/rocprim/include/rocprim/device/device_reduce_by_key.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -30,7 +30,6 @@ #include "detail/lookback_scan_state.hpp" #include "../config.hpp" -#include "../detail/match_result_type.hpp" #include "../detail/temp_storage.hpp" #include "../detail/various.hpp" #include "../functional.hpp" diff --git a/rocprim/include/rocprim/device/device_scan.hpp b/rocprim/include/rocprim/device/device_scan.hpp index 2cb648dad..305bd3416 100644 --- a/rocprim/include/rocprim/device/device_scan.hpp +++ b/rocprim/include/rocprim/device/device_scan.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -51,10 +51,10 @@ template + class AccType> ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void single_scan_kernel_impl(InputIterator input, const size_t input_size, - ResultType initial_value, + AccType initial_value, OutputIterator output, BinaryFunction scan_op) { @@ -63,14 +63,11 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void single_scan_kernel_impl(InputIterator constexpr unsigned int block_size = params.kernel_config.block_size; constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread; - using result_type = ResultType; - - using block_load_type = ::rocprim:: - block_load; - using block_store_type = ::rocprim:: - block_store; - using block_scan_type - = ::rocprim::block_scan; + using block_load_type + = ::rocprim::block_load; + using block_store_type + = ::rocprim::block_store; + using block_scan_type = ::rocprim::block_scan; ROCPRIM_SHARED_MEMORY union { @@ -79,7 +76,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void single_scan_kernel_impl(InputIterator typename block_scan_type::storage_type scan; } storage; - result_type values[items_per_thread]; + AccType values[items_per_thread]; // load input values into values block_load_type().load(input, values, input_size, *(input), storage.load); ::rocprim::syncthreads(); // sync threads to reuse shared memory @@ -100,7 +97,8 @@ template + class InitValueType, + class AccType> ROCPRIM_KERNEL __launch_bounds__(device_params().kernel_config.block_size) void single_scan_kernel( InputIterator input, @@ -111,7 +109,7 @@ ROCPRIM_KERNEL { single_scan_kernel_impl(input, size, - get_input_value(initial_value), + static_cast(get_input_value(initial_value)), output, scan_op); } @@ -124,32 +122,34 @@ template ROCPRIM_KERNEL __launch_bounds__(device_params().kernel_config.block_size) void lookback_scan_kernel( - InputIterator input, - OutputIterator output, - 
const size_t size, - const InitValueType initial_value, - BinaryFunction scan_op, - LookBackScanState lookback_scan_state, - const unsigned int number_of_blocks, - input_type_t* previous_last_element = nullptr, - input_type_t* new_last_element = nullptr, - bool override_first_value = false, - bool save_last_value = false) + InputIterator input, + OutputIterator output, + const size_t size, + const InitValueType initial_value, + BinaryFunction scan_op, + LookBackScanState lookback_scan_state, + const unsigned int number_of_blocks, + AccType* previous_last_element = nullptr, + AccType* new_last_element = nullptr, + bool override_first_value = false, + bool save_last_value = false) { - lookback_scan_kernel_impl(input, - output, - size, - get_input_value(initial_value), - scan_op, - lookback_scan_state, - number_of_blocks, - previous_last_element, - new_last_element, - override_first_value, - save_last_value); + lookback_scan_kernel_impl( + input, + output, + size, + static_cast(get_input_value(initial_value)), + scan_op, + lookback_scan_state, + number_of_blocks, + previous_last_element, + new_last_element, + override_first_value, + save_last_value); } #define ROCPRIM_DETAIL_HIP_SYNC(name, size, start) \ @@ -183,7 +183,8 @@ template + class BinaryFunction, + class AccType> inline auto scan_impl(void* temporary_storage, size_t& storage_size, InputIterator input, @@ -194,9 +195,7 @@ inline auto scan_impl(void* temporary_storage, const hipStream_t stream, bool debug_synchronous) { - using real_init_value_type = input_type_t; - - using config = wrapped_scan_config; + using config = wrapped_scan_config; detail::target_arch target_arch; hipError_t result = host_target_arch(stream, target_arch); @@ -206,8 +205,8 @@ inline auto scan_impl(void* temporary_storage, } const scan_config_params params = dispatch_target_arch(target_arch); - using scan_state_type = detail::lookback_scan_state; - using scan_state_with_sleep_type = detail::lookback_scan_state; + using scan_state_type = 
detail::lookback_scan_state; + using scan_state_with_sleep_type = detail::lookback_scan_state; const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -222,9 +221,9 @@ inline auto scan_impl(void* temporary_storage, unsigned int number_of_blocks = (limited_size + items_per_block - 1)/items_per_block; // Pointer to array with block_prefixes - void* scan_state_storage; - real_init_value_type* previous_last_element; - real_init_value_type* new_last_element; + void* scan_state_storage; + AccType* previous_last_element; + AccType* new_last_element; detail::temp_storage::layout layout{}; hipError_t layout_result @@ -304,24 +303,14 @@ inline auto scan_impl(void* temporary_storage, if(std::string(prop.gcnArchName).find("908") != std::string::npos && asicRevision < 2) { - hipLaunchKernelGGL( - HIP_KERNEL_NAME(init_lookback_scan_state_kernel), - dim3(grid_size), - dim3(block_size), - 0, - stream, - scan_state_with_sleep, - number_of_blocks); + init_lookback_scan_state_kernel + <<>>(scan_state_with_sleep, + number_of_blocks); } else { - hipLaunchKernelGGL( - HIP_KERNEL_NAME(init_lookback_scan_state_kernel), - dim3(grid_size), - dim3(block_size), - 0, - stream, - scan_state, - number_of_blocks); + init_lookback_scan_state_kernel + <<>>(scan_state, + number_of_blocks); } ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("init_lookback_scan_state_kernel", number_of_blocks, start) @@ -329,30 +318,25 @@ inline auto scan_impl(void* temporary_storage, grid_size = number_of_blocks; if(std::string(prop.gcnArchName).find("908") != std::string::npos && asicRevision < 2) { - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - lookback_scan_kernel), - dim3(grid_size), - dim3(block_size), - 0, - stream, - input + offset, - output + offset, - current_size, - initial_value, - scan_op, - scan_state_with_sleep, - number_of_blocks, - previous_last_element, - new_last_element, - i != size_t(0), - number_of_launch > 1); + 
lookback_scan_kernel + <<>>(input + offset, + output + offset, + current_size, + initial_value, + scan_op, + scan_state_with_sleep, + number_of_blocks, + previous_last_element, + new_last_element, + i != size_t(0), + number_of_launch > 1); } else { @@ -366,30 +350,25 @@ inline auto scan_impl(void* temporary_storage, std::cout << "items_per_block " << items_per_block << '\n'; } - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - lookback_scan_kernel), - dim3(grid_size), - dim3(block_size), - 0, - stream, - input + offset, - output + offset, - current_size, - initial_value, - scan_op, - scan_state, - number_of_blocks, - previous_last_element, - new_last_element, - i != size_t(0), - number_of_launch > 1); + lookback_scan_kernel + <<>>(input + offset, + output + offset, + current_size, + initial_value, + scan_op, + scan_state, + number_of_blocks, + previous_last_element, + new_last_element, + i != size_t(0), + number_of_launch > 1); } ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("lookback_scan_kernel", current_size, start) @@ -399,7 +378,7 @@ inline auto scan_impl(void* temporary_storage, hipError_t error = ::rocprim::transform(new_last_element, previous_last_element, 1, - ::rocprim::identity(), + ::rocprim::identity(), stream, debug_synchronous); if(error != hipSuccess) return error; @@ -414,24 +393,17 @@ inline auto scan_impl(void* temporary_storage, std::cout << "block_size " << block_size << '\n'; std::cout << "number of blocks " << number_of_blocks << '\n'; std::cout << "items_per_block " << items_per_block << '\n'; + start = std::chrono::high_resolution_clock::now(); } - if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - hipLaunchKernelGGL( - HIP_KERNEL_NAME(single_scan_kernel), - dim3(1), - dim3(block_size), - 0, - stream, - input, - size, - initial_value, - output, - scan_op); + single_scan_kernel + <<>>(input, size, initial_value, output, scan_op); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("single_scan_kernel", size, start); } return 
hipSuccess; @@ -455,7 +427,7 @@ inline auto scan_impl(void* temporary_storage, /// if \p temporary_storage in a null pointer. /// * Ranges specified by \p input and \p output must have at least \p size elements. /// * By default, the input type is used for accumulation. A custom type -/// can be specified using rocprim::transform_iterator, see the example below. +/// can be specified using the \p AccType type parameter, see the example below. /// /// \tparam Config - [optional] configuration of the primitive, has to be \p scan_config or a class derived from it. /// \tparam InputIterator - random-access iterator type of the input range. Must meet the @@ -464,6 +436,8 @@ inline auto scan_impl(void* temporary_storage, /// requirements of a C++ OutputIterator concept. It can be a simple pointer type. /// \tparam BinaryFunction - type of binary function used for scan. Default type /// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. +/// \tparam AccType - accumulator type used to propagate the scanned values. Default type +/// is value type of the input iterator. /// /// \param [in] temporary_storage - pointer to a device-accessible temporary storage. 
When /// a null pointer is passed, the required allocation size (in bytes) is written to @@ -526,31 +500,35 @@ inline auto scan_impl(void* temporary_storage, /// short * input; /// int * output; /// -/// // Use a transform iterator to specify a custom accumulator type -/// auto input_iterator = rocprim::make_transform_iterator( -/// input, [] __device__ (T in) { return static_cast(in); }); -/// /// size_t temporary_storage_size_bytes; /// void * temporary_storage_ptr = nullptr; -/// // Use the transform iterator +/// /// rocprim::inclusive_scan( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// input_iterator, output, input_size, rocprim::plus() +/// input, output, input_size, rocprim::plus() /// ); /// /// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); /// -/// rocprim::inclusive_scan( -/// temporary_storage_ptr, temporary_storage_size_bytes, -/// input_iterator, output, input_size, rocprim::plus() -/// ); +/// // Use type parameter to set custom accumulator type +/// rocprim::inclusive_scan, +/// int>(temporary_storage_ptr, +/// temporary_storage_size_bytes, +/// input, +/// output, +/// input_size, +/// rocprim::plus()); /// \endcode /// \endparblock template::value_type>> + = ::rocprim::plus::value_type>, + class AccType = typename std::iterator_traits::value_type> inline hipError_t inclusive_scan(void* temporary_storage, size_t& storage_size, InputIterator input, @@ -560,17 +538,18 @@ inline hipError_t inclusive_scan(void* temporary_storage, const hipStream_t stream = 0, bool debug_synchronous = false) { - using input_type = typename std::iterator_traits::value_type; // input_type() is a dummy initial value (not used) - return detail::scan_impl(temporary_storage, - storage_size, - input, - output, - input_type(), - size, - scan_op, - stream, - debug_synchronous); + return detail:: + scan_impl( + temporary_storage, + storage_size, + input, + output, + AccType{}, + size, + scan_op, + stream, + debug_synchronous); } ///
\brief Parallel exclusive scan primitive for device level. @@ -594,6 +573,8 @@ inline hipError_t inclusive_scan(void* temporary_storage, /// \tparam InitValueType - type of the initial value. /// \tparam BinaryFunction - type of binary function used for scan. Default type /// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. +/// \tparam AccType - accumulator type used to propagate the scanned values. Default type +/// is 'InitValueType', unless it's 'rocprim::future_value'. Then it will be the wrapped input type. /// /// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to @@ -654,7 +635,7 @@ inline hipError_t inclusive_scan(void* temporary_storage, /// temporary_storage_ptr, temporary_storage_size_bytes, /// input, output, start_value, input_size, min_op /// ); -/// // output: [9, 4, 7, 6, 2, 2, 1, 1] +/// // output: [9, 4, 4, 4, 2, 2, 1, 1] /// \endcode /// \endparblock template::value_type>> + = ::rocprim::plus::value_type>, + class AccType = detail::input_type_t> inline hipError_t exclusive_scan(void* temporary_storage, size_t& storage_size, InputIterator input, @@ -673,15 +655,21 @@ inline hipError_t exclusive_scan(void* temporary_storage, const hipStream_t stream = 0, bool debug_synchronous = false) { - return detail::scan_impl(temporary_storage, - storage_size, - input, - output, - initial_value, - size, - scan_op, - stream, - debug_synchronous); + return detail::scan_impl(temporary_storage, + storage_size, + input, + output, + initial_value, + size, + scan_op, + stream, + debug_synchronous); } /// @} diff --git a/rocprim/include/rocprim/device/device_scan_by_key.hpp b/rocprim/include/rocprim/device/device_scan_by_key.hpp index 8d8c4fecb..73d5fcb0b 100644 --- a/rocprim/include/rocprim/device/device_scan_by_key.hpp +++ b/rocprim/include/rocprim/device/device_scan_by_key.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 
Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -54,31 +54,32 @@ template + typename AccType> void __global__ __launch_bounds__(device_params().kernel_config.block_size) - device_scan_by_key_kernel(const KeyInputIterator keys, - const InputIterator values, - const OutputIterator output, - const InitialValueType initial_value, - const CompareFunction compare, - const BinaryFunction scan_op, - const LookbackScanState scan_state, - const size_t size, - const size_t starting_block, - const size_t number_of_blocks, - const ::rocprim::tuple* const previous_last_value) + device_scan_by_key_kernel(const KeyInputIterator keys, + const InputIterator values, + const OutputIterator output, + const InitialValueType initial_value, + const CompareFunction compare, + const BinaryFunction scan_op, + const LookbackScanState scan_state, + const size_t size, + const size_t starting_block, + const size_t number_of_blocks, + const ::rocprim::tuple* const previous_last_value) { - device_scan_by_key_kernel_impl(keys, - values, - output, - get_input_value(initial_value), - compare, - scan_op, - scan_state, - size, - starting_block, - number_of_blocks, - previous_last_value); + device_scan_by_key_kernel_impl( + keys, + values, + output, + static_cast(get_input_value(initial_value)), + compare, + scan_op, + scan_state, + size, + starting_block, + number_of_blocks, + previous_last_value); } #define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ @@ -106,7 +107,8 @@ template + typename CompareFunction, + typename AccType> inline hipError_t scan_by_key_impl(void* const temporary_storage, size_t& storage_size, KeysInputIterator keys, @@ -119,10 +121,9 @@ inline hipError_t scan_by_key_impl(void* const temporary_storage, const hipStream_t 
stream, const bool debug_synchronous) { - using key_type = typename std::iterator_traits::value_type; - using real_init_value_type = input_type_t; + using key_type = typename std::iterator_traits::value_type; - using config = wrapped_scan_by_key_config; + using config = wrapped_scan_by_key_config; detail::target_arch target_arch; hipError_t result = host_target_arch(stream, target_arch); @@ -132,7 +133,7 @@ inline hipError_t scan_by_key_impl(void* const temporary_storage, } const scan_by_key_config_params params = dispatch_target_arch(target_arch); - using wrapped_type = ::rocprim::tuple; + using wrapped_type = ::rocprim::tuple; using scan_state_type = detail::lookback_scan_state; using scan_state_with_sleep_type = detail::lookback_scan_state; @@ -284,7 +285,7 @@ inline hipError_t scan_by_key_impl(void* const temporary_storage, keys + offset, input + offset, output + offset, - static_cast(initial_value), + initial_value, compare, scan_op, scan_state, @@ -331,6 +332,8 @@ inline hipError_t scan_by_key_impl(void* const temporary_storage, /// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. /// \tparam KeyCompareFunction - type of binary function used to determine keys equality. Default type /// is \p rocprim::equal_to, where \p T is a \p value_type of \p KeysInputIterator. +/// \tparam AccType - accumulator type used to propagate the scanned values. Default type +/// is value type of the input iterator. /// /// \param [in] temporary_storage - pointer to a device-accessible temporary storage. 
When /// a null pointer is passed, the required allocation size (in bytes) is written to @@ -402,7 +405,8 @@ template::value_type>, typename KeyCompareFunction - = ::rocprim::equal_to::value_type>> + = ::rocprim::equal_to::value_type>, + typename AccType = typename std::iterator_traits::value_type> inline hipError_t inclusive_scan_by_key(void* const temporary_storage, size_t& storage_size, const KeysInputIterator keys_input, @@ -416,17 +420,25 @@ inline hipError_t inclusive_scan_by_key(void* const temporary_sto const bool debug_synchronous = false) { using value_type = typename std::iterator_traits::value_type; - return detail::scan_by_key_impl(temporary_storage, - storage_size, - keys_input, - values_input, - values_output, - value_type(), - size, - scan_op, - key_compare_op, - stream, - debug_synchronous); + return detail::scan_by_key_impl(temporary_storage, + storage_size, + keys_input, + values_input, + values_output, + value_type(), + size, + scan_op, + key_compare_op, + stream, + debug_synchronous); } /// \brief Parallel exclusive scan-by-key primitive for device level. @@ -455,6 +467,8 @@ inline hipError_t inclusive_scan_by_key(void* const temporary_sto /// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. /// \tparam KeyCompareFunction - type of binary function used to determine keys equality. Default type /// is \p rocprim::equal_to, where \p T is a \p value_type of \p KeysInputIterator. +/// \tparam AccType - accumulator type used to propagate the scanned values. Default type +/// is 'InitValueType', unless it's 'rocprim::future_value'. Then it will be the wrapped input type. /// /// \param [in] temporary_storage - pointer to a device-accessible temporary storage. 
When /// a null pointer is passed, the required allocation size (in bytes) is written to @@ -530,7 +544,8 @@ template::value_type>, typename KeyCompareFunction - = ::rocprim::equal_to::value_type>> + = ::rocprim::equal_to::value_type>, + typename AccType = detail::input_type_t> inline hipError_t exclusive_scan_by_key(void* const temporary_storage, size_t& storage_size, const KeysInputIterator keys_input, @@ -544,17 +559,25 @@ inline hipError_t exclusive_scan_by_key(void* const temporary_sto const hipStream_t stream = 0, const bool debug_synchronous = false) { - return detail::scan_by_key_impl(temporary_storage, - storage_size, - keys_input, - values_input, - values_output, - initial_value, - size, - scan_op, - key_compare_op, - stream, - debug_synchronous); + return detail::scan_by_key_impl(temporary_storage, + storage_size, + keys_input, + values_input, + values_output, + initial_value, + size, + scan_op, + key_compare_op, + stream, + debug_synchronous); } /// @} diff --git a/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp b/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp index 576789f06..58089b810 100644 --- a/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp +++ b/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp @@ -25,10 +25,12 @@ #include #include #include +#include #include "../config.hpp" -#include "../detail/various.hpp" #include "../detail/radix_sort.hpp" +#include "../detail/various.hpp" +#include "config_types.hpp" #include "../intrinsics.hpp" #include "../functional.hpp" @@ -49,31 +51,28 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class Config, - bool Descending, - unsigned int BlockSize, - class KeysInputIterator, - class KeysOutputIterator, - class ValuesInputIterator, - class ValuesOutputIterator, - class OffsetIterator -> +template ROCPRIM_KERNEL -__launch_bounds__(BlockSize) -void segmented_sort_kernel(KeysInputIterator keys_input, - typename std::iterator_traits::value_type * 
keys_tmp, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, - typename std::iterator_traits::value_type * values_tmp, - ValuesOutputIterator values_output, - bool to_output, - OffsetIterator begin_offsets, - OffsetIterator end_offsets, - unsigned int long_iterations, - unsigned int short_iterations, - unsigned int begin_bit, - unsigned int end_bit) + __launch_bounds__(device_params().kernel_config.block_size) void segmented_sort_kernel( + KeysInputIterator keys_input, + typename std::iterator_traits::value_type* keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits::value_type* values_tmp, + ValuesOutputIterator values_output, + bool to_output, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int long_iterations, + unsigned int short_iterations, + unsigned int begin_bit, + unsigned int end_bit) { segmented_sort( keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, @@ -84,33 +83,34 @@ void segmented_sort_kernel(KeysInputIterator keys_input, ); } -template< - class Config, - bool Descending, - unsigned int BlockSize, - class KeysInputIterator, - class KeysOutputIterator, - class ValuesInputIterator, - class ValuesOutputIterator, - class SegmentIndexIterator, - class OffsetIterator -> -ROCPRIM_KERNEL -__launch_bounds__(BlockSize) -void segmented_sort_large_kernel(KeysInputIterator keys_input, - typename std::iterator_traits::value_type * keys_tmp, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, - typename std::iterator_traits::value_type * values_tmp, - ValuesOutputIterator values_output, - bool to_output, - SegmentIndexIterator segment_indices, - OffsetIterator begin_offsets, - OffsetIterator end_offsets, - unsigned int long_iterations, - unsigned int short_iterations, - unsigned int begin_bit, - unsigned int end_bit) +template +ROCPRIM_KERNEL __launch_bounds__( + device_params() + .kernel_config + .block_size) void 
segmented_sort_large_kernel(KeysInputIterator keys_input, + typename std::iterator_traits< + KeysInputIterator>::value_type* keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits< + ValuesInputIterator>::value_type* + values_tmp, + ValuesOutputIterator values_output, + bool to_output, + SegmentIndexIterator segment_indices, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int long_iterations, + unsigned int short_iterations, + unsigned int begin_bit, + unsigned int end_bit) { segmented_sort_large( keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, @@ -122,28 +122,33 @@ void segmented_sort_large_kernel(KeysInputIterator keys_input, } template -ROCPRIM_KERNEL __launch_bounds__(BlockSize) void segmented_sort_small_or_medium_kernel( - KeysInputIterator keys_input, - typename std::iterator_traits::value_type* keys_tmp, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, - typename std::iterator_traits::value_type* values_tmp, - ValuesOutputIterator values_output, - bool to_output, - unsigned int num_segments, - SegmentIndexIterator segment_indices, - OffsetIterator begin_offsets, - OffsetIterator end_offsets, - unsigned int begin_bit, - unsigned int end_bit) +ROCPRIM_KERNEL __launch_bounds__( + device_params() + .warp_sort_config + .block_size_small) void segmented_sort_small_kernel(KeysInputIterator keys_input, + typename std::iterator_traits< + KeysInputIterator>::value_type* + keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits< + ValuesInputIterator>::value_type* + values_tmp, + ValuesOutputIterator values_output, + bool to_output, + unsigned int num_segments, + SegmentIndexIterator segment_indices, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int begin_bit, + unsigned int end_bit) { segmented_sort_small( keys_input, keys_tmp, keys_output, values_input, 
values_tmp, values_output, @@ -153,6 +158,50 @@ ROCPRIM_KERNEL __launch_bounds__(BlockSize) void segmented_sort_small_or_medium_ ); } +template +ROCPRIM_KERNEL __launch_bounds__( + device_params() + .warp_sort_config + .block_size_medium) void segmented_sort_medium_kernel(KeysInputIterator keys_input, + typename std::iterator_traits< + KeysInputIterator>::value_type* + keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits< + ValuesInputIterator>::value_type* + values_tmp, + ValuesOutputIterator values_output, + bool to_output, + unsigned int num_segments, + SegmentIndexIterator segment_indices, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int begin_bit, + unsigned int end_bit) +{ + segmented_sort_medium(keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + to_output, + num_segments, + segment_indices, + begin_offsets, + end_offsets, + begin_bit, + end_bit); +} + #define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ { \ auto _error = hipGetLastError(); \ @@ -168,42 +217,12 @@ ROCPRIM_KERNEL __launch_bounds__(BlockSize) void segmented_sort_small_or_medium_ } \ } -struct TwoWayPartitioner +struct Partitioner { - template - hipError_t operator()(void* temporary_storage, - size_t& storage_size, - InputIterator input, - FirstOutputIterator output_first_part, - SecondOutputIterator /*output_second_part*/, - UnselectedOutputIterator /*output_unselected*/, - SelectedCountOutputIterator selected_count_output, - const size_t size, - FirstUnaryPredicate select_first_part_op, - SecondUnaryPredicate /*select_second_part_op*/, - const hipStream_t stream, - const bool debug_synchronous) - { - return partition(temporary_storage, - storage_size, - input, - output_first_part, - selected_count_output, - size, - select_first_part_op, - stream, - debug_synchronous); - } -}; + bool three_way_partitioning; + + Partitioner(bool three_way_part) : 
three_way_partitioning(three_way_part) {} -struct ThreeWayPartitioner -{ template - >; + using config = wrapped_segmented_radix_sort_config; + + detail::target_arch target_arch; + hipError_t result = detail::host_target_arch(stream, target_arch); + if(result != hipSuccess) + { + return result; + } + + const detail::segmented_radix_sort_config_params params + = detail::dispatch_target_arch(target_arch); static constexpr bool with_values = !std::is_same::value; - static constexpr bool partitioning_allowed = - !std::is_same::value; - static constexpr unsigned int max_small_segment_length - = config::warp_sort_config::items_per_thread_small - * config::warp_sort_config::logical_warp_size_small; - static constexpr unsigned int small_segments_per_block - = config::warp_sort_config::block_size_small - / config::warp_sort_config::logical_warp_size_small; - static constexpr unsigned int max_medium_segment_length - = config::warp_sort_config::items_per_thread_medium - * config::warp_sort_config::logical_warp_size_medium; - static constexpr unsigned int medium_segments_per_block - = config::warp_sort_config::block_size_medium - / config::warp_sort_config::logical_warp_size_medium; - static_assert( - max_small_segment_length <= max_medium_segment_length, - "The max length of small segments cannot be higher than the max length of medium segments"); - // Don't waste cycles on 3-way partitioning, if the small and medium segments are equal length - static constexpr bool three_way_partitioning - = max_small_segment_length < max_medium_segment_length; - using partitioner_type - = std::conditional_t; - partitioner_type partitioner; + const bool partitioning_allowed = params.warp_sort_config.partitioning_allowed; + const unsigned int max_small_segment_length = params.warp_sort_config.items_per_thread_small + * params.warp_sort_config.logical_warp_size_small; + const unsigned int small_segments_per_block = params.warp_sort_config.block_size_small + / 
params.warp_sort_config.logical_warp_size_small; + const unsigned int max_medium_segment_length + = params.warp_sort_config.items_per_thread_medium + * params.warp_sort_config.logical_warp_size_medium; + const unsigned int medium_segments_per_block + = params.warp_sort_config.block_size_medium + / params.warp_sort_config.logical_warp_size_medium; + + const bool three_way_partitioning = max_small_segment_length < max_medium_segment_length; + Partitioner partitioner(three_way_partitioning); const auto large_segment_selector = [=](const unsigned int segment_index) mutable -> bool { @@ -325,20 +357,22 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, const bool with_double_buffer = keys_tmp != nullptr; const unsigned int bits = end_bit - begin_bit; - const unsigned int iterations = ::rocprim::detail::ceiling_div(bits, config::long_radix_bits); + const unsigned int iterations = ::rocprim::detail::ceiling_div(bits, params.long_radix_bits); const bool to_output = with_double_buffer || (iterations - 1) % 2 == 0; is_result_in_output = (iterations % 2 == 0) != to_output; - const unsigned int radix_bits_diff = config::long_radix_bits - config::short_radix_bits; - const unsigned int short_iterations = radix_bits_diff != 0 - ? ::rocprim::min(iterations, (config::long_radix_bits * iterations - bits) / radix_bits_diff) - : 0; + const unsigned int radix_bits_diff = params.long_radix_bits - params.short_radix_bits; + const unsigned int short_iterations + = radix_bits_diff != 0 + ? ::rocprim::min(iterations, + (params.long_radix_bits * iterations - bits) / radix_bits_diff) + : 0; const unsigned int long_iterations = iterations - short_iterations; - const bool do_partitioning = partitioning_allowed - && segments >= config::warp_sort_config::partitioning_threshold; + const bool do_partitioning + = partitioning_allowed && segments >= params.warp_sort_config.partitioning_threshold; - const size_t medium_segment_indices_size = three_way_partitioning ? 
segments : 0; - static constexpr size_t segment_count_output_size = three_way_partitioning ? 2 : 1; - const size_t segment_count_output_bytes + const size_t medium_segment_indices_size = three_way_partitioning ? segments : 0; + const size_t segment_count_output_size = three_way_partitioning ? 2 : 1; + const size_t segment_count_output_bytes = segment_count_output_size * sizeof(segment_index_type); segment_index_type* large_segment_indices_output{}; @@ -412,10 +446,14 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, std::cout << "long_iterations " << long_iterations << '\n'; std::cout << "short_iterations " << short_iterations << '\n'; std::cout << "do_partitioning " << do_partitioning << '\n'; - std::cout << "config::sort::block_size: " << config::sort::block_size << '\n'; - std::cout << "config::sort::items_per_thread: " << config::sort::items_per_thread << '\n'; + std::cout << "params.kernel_config.block_size: " << params.kernel_config.block_size << '\n'; + std::cout << "params.kernel_config.items_per_thread: " + << params.kernel_config.items_per_thread << '\n'; hipError_t error = hipStreamSynchronize(stream); - if(error != hipSuccess) return error; + if(error != hipSuccess) + { + return error; + } } if(!with_double_buffer) @@ -443,8 +481,9 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, { return result; } - segment_index_type segment_counts[segment_count_output_size]{}; - result = detail::memcpy_and_sync(&segment_counts, + std::vector segment_counts(segment_count_output_size, + segment_index_type{}); + result = detail::memcpy_and_sync(segment_counts.data(), segment_count_output, segment_count_output_bytes, hipMemcpyDeviceToHost, @@ -466,15 +505,25 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, { std::chrono::high_resolution_clock::time_point start; if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - hipLaunchKernelGGL( - HIP_KERNEL_NAME(segmented_sort_large_kernel), - 
dim3(large_segment_count), dim3(config::sort::block_size), 0, stream, - keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, - to_output, large_segment_indices_output, - begin_offsets, end_offsets, - long_iterations, short_iterations, - begin_bit, end_bit - ); + hipLaunchKernelGGL(HIP_KERNEL_NAME(segmented_sort_large_kernel), + dim3(large_segment_count), + dim3(params.kernel_config.block_size), + 0, + stream, + keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + to_output, + large_segment_indices_output, + begin_offsets, + end_offsets, + long_iterations, + short_iterations, + begin_bit, + end_bit); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:large_segments", large_segment_count, start) @@ -486,29 +535,24 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, std::chrono::high_resolution_clock::time_point start; if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - segmented_sort_small_or_medium_kernel< - select_warp_sort_helper_config_medium_t, - Descending, - config::warp_sort_config::block_size_medium>), - dim3(medium_segment_grid_size), - dim3(config::warp_sort_config::block_size_medium), - 0, - stream, - keys_input, - keys_tmp, - keys_output, - values_input, - values_tmp, - values_output, - is_result_in_output, - medium_segment_count, - medium_segment_indices_output, - begin_offsets, - end_offsets, - begin_bit, - end_bit); + hipLaunchKernelGGL(HIP_KERNEL_NAME(segmented_sort_medium_kernel), + dim3(medium_segment_grid_size), + dim3(params.warp_sort_config.block_size_medium), + 0, + stream, + keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + is_result_in_output, + medium_segment_count, + medium_segment_indices_output, + begin_offsets, + end_offsets, + begin_bit, + end_bit); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:medium_segments", medium_segment_count, start) @@ 
-519,29 +563,24 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, small_segments_per_block); std::chrono::high_resolution_clock::time_point start; if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - segmented_sort_small_or_medium_kernel< - select_warp_sort_helper_config_small_t, - Descending, - config::warp_sort_config::block_size_small>), - dim3(small_segment_grid_size), - dim3(config::warp_sort_config::block_size_small), - 0, - stream, - keys_input, - keys_tmp, - keys_output, - values_input, - values_tmp, - values_output, - is_result_in_output, - small_segment_count, - small_segment_indices_output, - begin_offsets, - end_offsets, - begin_bit, - end_bit); + hipLaunchKernelGGL(HIP_KERNEL_NAME(segmented_sort_small_kernel), + dim3(small_segment_grid_size), + dim3(params.warp_sort_config.block_size_small), + 0, + stream, + keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + is_result_in_output, + small_segment_count, + small_segment_indices_output, + begin_offsets, + end_offsets, + begin_bit, + end_bit); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:small_segments", small_segment_count, start) @@ -551,15 +590,24 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, { std::chrono::high_resolution_clock::time_point start; if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - hipLaunchKernelGGL( - HIP_KERNEL_NAME(segmented_sort_kernel), - dim3(segments), dim3(config::sort::block_size), 0, stream, - keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, - to_output, - begin_offsets, end_offsets, - long_iterations, short_iterations, - begin_bit, end_bit - ); + hipLaunchKernelGGL(HIP_KERNEL_NAME(segmented_sort_kernel), + dim3(segments), + dim3(params.kernel_config.block_size), + 0, + stream, + keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + to_output, + 
begin_offsets, + end_offsets, + long_iterations, + short_iterations, + begin_bit, + end_bit); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort", segments, start) } return hipSuccess; diff --git a/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp b/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp index 7d5b248a8..62def5dff 100644 --- a/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp +++ b/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2018-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -29,88 +29,14 @@ #include "../functional.hpp" #include "config_types.hpp" +#include "detail/config/device_segmented_radix_sort.hpp" +#include "detail/device_config_helper.hpp" /// \addtogroup primitivesmodule_deviceconfigs /// @{ BEGIN_ROCPRIM_NAMESPACE -/// \brief Configuration of the warp sort part of the device segmented radix sort operation. -/// Short enough segments are processed on warp level. -/// -/// \tparam LogicalWarpSizeSmall - number of threads in the logical warp of the kernel -/// that processes small segments. -/// \tparam ItemsPerThreadSmall - number of items processed by a thread in the kernel that processes -/// small segments. -/// \tparam BlockSizeSmall - number of threads per block in the kernel which processes the small segments. -/// \tparam PartitioningThreshold - if the number of segments is at least this threshold, the -/// segments are partitioned to a small, a medium and a large segment collection. Both collections -/// are sorted by different kernels. Otherwise, all segments are sorted by a single kernel. 
-/// \tparam EnableUnpartitionedWarpSort - If set to \p true, warp sort can be used to sort -/// the small segments, even if the total number of segments is below \p PartitioningThreshold. -/// \tparam LogicalWarpSizeMedium - number of threads in the logical warp of the kernel -/// that processes medium segments. -/// \tparam ItemsPerThreadMedium - number of items processed by a thread in the kernel that processes -/// medium segments. -/// \tparam BlockSizeMedium - number of threads per block in the kernel which processes the medium segments. -template -struct WarpSortConfig -{ - static_assert(LogicalWarpSizeSmall * ItemsPerThreadSmall - <= LogicalWarpSizeMedium * ItemsPerThreadMedium, - "The number of items processed by a small warp cannot be larger than the number " - "of items processed by a medium warp"); - /// \brief The number of threads in the logical warp in the small segment processing kernel. - static constexpr unsigned int logical_warp_size_small = LogicalWarpSizeSmall; - /// \brief The number of items processed by a thread in the small segment processing kernel. - static constexpr unsigned int items_per_thread_small = ItemsPerThreadSmall; - /// \brief The number of threads per block in the small segment processing kernel. - static constexpr unsigned int block_size_small = BlockSizeSmall; - /// \brief If the number of segments is at least \p partitioning_threshold, then the segments are partitioned into - /// small and large segment groups, and each group is handled by a different, specialized kernel. - static constexpr unsigned int partitioning_threshold = PartitioningThreshold; - /// \brief If set to \p true, warp sort can be used to sort the small segments, even if the total number of - /// segments is below \p PartitioningThreshold. - static constexpr bool enable_unpartitioned_warp_sort = EnableUnpartitionedWarpSort; - /// \brief The number of threads in the logical warp in the medium segment processing kernel. 
- static constexpr unsigned int logical_warp_size_medium = LogicalWarpSizeMedium; - /// \brief The number of items processed by a thread in the medium segment processing kernel. - static constexpr unsigned int items_per_thread_medium = ItemsPerThreadMedium; - /// \brief The number of threads per block in the medium segment processing kernel. - static constexpr unsigned int block_size_medium = BlockSizeMedium; -}; - -/// \brief Indicates if the warp level sorting is disabled in the -/// device segmented radix sort configuration. -struct DisabledWarpSortConfig -{ - /// \brief The number of threads in the logical warp in the small segment processing kernel. - static constexpr unsigned int logical_warp_size_small = 1; - /// \brief The number of items processed by a thread in the small segment processing kernel. - static constexpr unsigned int items_per_thread_small = 1; - /// \brief The number of threads per block in the small segment processing kernel. - static constexpr unsigned int block_size_small = 1; - /// \brief If the number of segments is at least \p partitioning_threshold, then the segments are partitioned into - /// small and large segment groups, and each group is handled by a different, specialized kernel. - static constexpr unsigned int partitioning_threshold = 0; - /// \brief If set to \p true, warp sort can be used to sort the small segments, even if the total number of - /// segments is below \p PartitioningThreshold. - static constexpr bool enable_unpartitioned_warp_sort = false; - /// \brief The number of threads in the logical warp in the medium segment processing kernel. - static constexpr unsigned int logical_warp_size_medium = 1; - /// \brief The number of items processed by a thread in the medium segment processing kernel. - static constexpr unsigned int items_per_thread_medium = 1; - /// \brief The number of threads per block in the medium segment processing kernel. 
- static constexpr unsigned int block_size_medium = 1; -}; - /// \brief Selects the appropriate \p WarpSortConfig based on the size of the key type. /// /// \tparam Key - the type of the sorted keys. @@ -123,235 +49,54 @@ using select_warp_sort_config_t 4, //< items per thread - small kernel 256, //< block size - small kernel 3000, //< partitioning threshold - (sizeof(Key) > 2), //< enable unpartitioned warp sort MediumWarpSize, //< logical warp size - medium kernel 4, //< items per thread - medium kernel 256 //< block size - medium kernel >>; -/// \brief Configuration of device-level segmented radix sort operation. -/// -/// Radix sort is excecuted in a few iterations (passes) depending on total number of bits to be sorted -/// (\p begin_bit and \p end_bit), each iteration sorts either \p LongRadixBits or \p ShortRadixBits bits -/// choosen to cover whole bit range in optimal way. -/// -/// For example, if \p LongRadixBits is 7, \p ShortRadixBits is 6, \p begin_bit is 0 and \p end_bit is 32 -/// there will be 5 iterations: 7 + 7 + 6 + 6 + 6 = 32 bits. -/// -/// If a segment's element count is low ( <= warp_sort_config::items_per_thread * warp_sort_config::logical_warp_size ), -/// it is sorted by a special warp-level sorting method. -/// -/// \tparam LongRadixBits - number of bits in long iterations. -/// \tparam ShortRadixBits - number of bits in short iterations, must be equal to or less than \p LongRadixBits. -/// \tparam SortConfig - configuration of radix sort kernel. Must be \p kernel_config. -/// \tparam WarpSortConfig - configuration of the warp sort that is used on the short segments. -template< - unsigned int LongRadixBits, - unsigned int ShortRadixBits, - class SortConfig, - class WarpSortConfig = DisabledWarpSortConfig -> -struct segmented_radix_sort_config -{ - /// \brief Number of bits in long iterations. 
- static constexpr unsigned int long_radix_bits = LongRadixBits; - /// \brief Number of bits in short iterations - static constexpr unsigned int short_radix_bits = ShortRadixBits; - /// \brief Configuration of radix sort kernel. - using sort = SortConfig; - /// \brief Configuration of the warp sort method. - using warp_sort_config = WarpSortConfig; -}; - namespace detail { -template -struct segmented_radix_sort_config_803 +template +struct wrapped_segmented_radix_sort_config { - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); - - using type = select_type< - select_type_case< - (sizeof(Key) == 1 && sizeof(Value) <= 8), - segmented_radix_sort_config<8, 7, kernel_config<256, 10>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 2 && sizeof(Value) <= 8), - segmented_radix_sort_config<8, 7, kernel_config<256, 10>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 4 && sizeof(Value) <= 8), - segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 8 && sizeof(Value) <= 8), - segmented_radix_sort_config<7, 6, kernel_config<256, 13>, select_warp_sort_config_t > - >, - segmented_radix_sort_config<7, 6, kernel_config<256, ::rocprim::max(1u, 15u / item_scale)>, select_warp_sort_config_t > - >; -}; - -template -struct segmented_radix_sort_config_803 - : select_type< - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > > - > { }; - -template -struct segmented_radix_sort_config_900 -{ - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); - - using type = select_type< - select_type_case< - (sizeof(Key) == 1 && sizeof(Value) <= 8), - 
segmented_radix_sort_config<4, 4, kernel_config<256, 10>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 2 && sizeof(Value) <= 8), - segmented_radix_sort_config<6, 5, kernel_config<256, 10>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 4 && sizeof(Value) <= 8), - segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 8 && sizeof(Value) <= 8), - segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t > - >, - segmented_radix_sort_config<7, 6, kernel_config<256, ::rocprim::max(1u, 15u / item_scale)>, select_warp_sort_config_t > - >; -}; - -template -struct segmented_radix_sort_config_900 - : select_type< - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > > - > { }; - -template -struct segmented_radix_sort_config_90a -{ - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); - - using type = select_type< - select_type_case< - (sizeof(Key) == 1 && sizeof(Value) <= 8), - segmented_radix_sort_config<4, - 4, - kernel_config<256, 10>, - select_warp_sort_config_t>>, - select_type_case< - (sizeof(Key) == 2 && sizeof(Value) <= 8), - segmented_radix_sort_config<6, - 5, - kernel_config<256, 10>, - select_warp_sort_config_t>>, - select_type_case< - (sizeof(Key) == 4 && sizeof(Value) <= 8), - segmented_radix_sort_config<7, - 6, - kernel_config<256, 15>, - select_warp_sort_config_t>>, - select_type_case< - (sizeof(Key) == 8 && sizeof(Value) <= 8), - segmented_radix_sort_config<7, - 6, - kernel_config<256, 15>, - select_warp_sort_config_t>>, - segmented_radix_sort_config<7, - 6, - kernel_config<256, ::rocprim::max(1u, 15u / item_scale)>, - select_warp_sort_config_t>>; + static_assert(std::is_same::value, + 
"Config must be a specialization of struct template segmented_radix_sort_config"); + + template + struct architecture_config + { + static constexpr detail::segmented_radix_sort_config_params params + = SegmentedRadixSortConfig{}; + }; }; -template -struct segmented_radix_sort_config_90a - : select_type< - select_type_case< - sizeof(Key) == 1, - segmented_radix_sort_config<4, - 3, - kernel_config<256, 10>, - select_warp_sort_config_t>>, - select_type_case< - sizeof(Key) == 2, - segmented_radix_sort_config<6, - 5, - kernel_config<256, 10>, - select_warp_sort_config_t>>, - select_type_case< - sizeof(Key) == 4, - segmented_radix_sort_config<7, - 6, - kernel_config<256, 17>, - select_warp_sort_config_t>>, - select_type_case< - sizeof(Key) == 8, - segmented_radix_sort_config<7, - 6, - kernel_config<256, 15>, - select_warp_sort_config_t>>> -{}; - -template -struct segmented_radix_sort_config_1030 +template +struct wrapped_segmented_radix_sort_config { - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); - - using type = select_type< - select_type_case< - (sizeof(Key) == 1 && sizeof(Value) <= 8), - segmented_radix_sort_config<4, 4, kernel_config<256, 10>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 2 && sizeof(Value) <= 8), - segmented_radix_sort_config<6, 5, kernel_config<256, 10>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 4 && sizeof(Value) <= 8), - segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 8 && sizeof(Value) <= 8), - segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t > - >, - segmented_radix_sort_config<7, 6, kernel_config<256, ::rocprim::max(1u, 15u / item_scale)>, select_warp_sort_config_t > - >; + template + struct architecture_config + { + static constexpr segmented_radix_sort_config_params params + = 
detail::default_segmented_radix_sort_config(Arch), + key_type, + value_type>{}; + }; }; -template -struct segmented_radix_sort_config_1030 - : select_type< - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > > - > { }; - -template -struct default_segmented_radix_sort_config - : select_arch< - TargetArch, - select_arch_case<803, detail::segmented_radix_sort_config_803>, - select_arch_case<900, detail::segmented_radix_sort_config_900>, - select_arch_case<906, detail::segmented_radix_sort_config_90a>, - select_arch_case<908, detail::segmented_radix_sort_config_90a>, - select_arch_case>, - select_arch_case<1030, detail::segmented_radix_sort_config_1030>, - detail::segmented_radix_sort_config_900> -{}; +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template +template +constexpr segmented_radix_sort_config_params + wrapped_segmented_radix_sort_config:: + architecture_config::params; +template +template +constexpr segmented_radix_sort_config_params + wrapped_segmented_radix_sort_config:: + architecture_config::params; +#endif // DOXYGEN_SHOULD_SKIP_THIS } // end namespace detail diff --git a/rocprim/include/rocprim/device/device_segmented_reduce.hpp b/rocprim/include/rocprim/device/device_segmented_reduce.hpp index 424b291ec..d3e9a76a7 100644 --- a/rocprim/include/rocprim/device/device_segmented_reduce.hpp +++ b/rocprim/include/rocprim/device/device_segmented_reduce.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -26,12 +26,12 @@ #include #include "../config.hpp" -#include "../functional.hpp" #include "../detail/various.hpp" -#include "../detail/match_result_type.hpp" +#include "../functional.hpp" #include "detail/config/device_reduce.hpp" #include "detail/device_segmented_reduce.hpp" +#include "rocprim/type_traits.hpp" BEGIN_ROCPRIM_NAMESPACE @@ -100,9 +100,8 @@ hipError_t segmented_reduce_impl(void * temporary_storage, bool debug_synchronous) { using input_type = typename std::iterator_traits::value_type; - using result_type = typename ::rocprim::detail::match_result_type< - input_type, BinaryFunction - >::type; + using result_type = + typename ::rocprim::invoke_result_binary_op::type; using config = wrapped_reduce_config; diff --git a/rocprim/include/rocprim/device/device_segmented_scan.hpp b/rocprim/include/rocprim/device/device_segmented_scan.hpp index 4cad6e0f4..e54564fbb 100644 --- a/rocprim/include/rocprim/device/device_segmented_scan.hpp +++ b/rocprim/include/rocprim/device/device_segmented_scan.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -27,7 +27,6 @@ #include "../config.hpp" #include "../detail/various.hpp" -#include "../detail/match_result_type.hpp" #include "../iterator/zip_iterator.hpp" #include "../iterator/discard_iterator.hpp" diff --git a/rocprim/include/rocprim/device/device_transform.hpp b/rocprim/include/rocprim/device/device_transform.hpp index a9de7d827..4238290a1 100644 --- a/rocprim/include/rocprim/device/device_transform.hpp +++ b/rocprim/include/rocprim/device/device_transform.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -28,9 +28,8 @@ #include "../config.hpp" #include "../detail/various.hpp" -#include "../detail/match_result_type.hpp" -#include "../types/tuple.hpp" #include "../iterator/zip_iterator.hpp" +#include "../types/tuple.hpp" #include "device_transform_config.hpp" #include "detail/device_transform.hpp" @@ -142,7 +141,7 @@ inline hipError_t transform(InputIterator input, return hipSuccess; using input_type = typename std::iterator_traits::value_type; - using result_type = typename ::rocprim::detail::invoke_result::type; + using result_type = typename ::rocprim::invoke_result::type; using config = detail::wrapped_transform_config; diff --git a/rocprim/include/rocprim/device/device_transform_config.hpp b/rocprim/include/rocprim/device/device_transform_config.hpp index bf6c6c1a2..f350fe48a 100644 --- a/rocprim/include/rocprim/device/device_transform_config.hpp +++ b/rocprim/include/rocprim/device/device_transform_config.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2018-2023 Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -39,7 +39,7 @@ namespace detail // device transform does not have config tuning template -struct default_transform_config : default_transform_config_base +struct default_transform_config : default_transform_config_base::type {}; template diff --git a/rocprim/include/rocprim/intrinsics/warp.hpp b/rocprim/include/rocprim/intrinsics/warp.hpp index ba192af87..9c1d1eded 100644 --- a/rocprim/include/rocprim/intrinsics/warp.hpp +++ b/rocprim/include/rocprim/intrinsics/warp.hpp @@ -116,9 +116,6 @@ int warp_all(int predicate) } // end detail namespace -/// @} -// end of group intrinsicsmodule - /// \brief Group active lanes having the same bits of \p label /// /// Threads that have the same least significant \p LabelBits bits are grouped into the same group. @@ -185,6 +182,9 @@ ROCPRIM_DEVICE ROCPRIM_INLINE bool group_elect(lane_mask_type mask) return prev_same_count == 0 && mask != 0; } +/// @} +// end of group intrinsicsmodule + END_ROCPRIM_NAMESPACE #endif // ROCPRIM_INTRINSICS_WARP_HPP_ diff --git a/rocprim/include/rocprim/iterator/transform_iterator.hpp b/rocprim/include/rocprim/iterator/transform_iterator.hpp index db3d9a000..e98b8f03d 100644 --- a/rocprim/include/rocprim/iterator/transform_iterator.hpp +++ b/rocprim/include/rocprim/iterator/transform_iterator.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -26,7 +26,7 @@ #include #include "../config.hpp" -#include "../detail/match_result_type.hpp" +#include "../type_traits.hpp" /// \addtogroup iteratormodule /// @{ @@ -47,14 +47,11 @@ BEGIN_ROCPRIM_NAMESPACE /// \tparam UnaryFunction - type of the transform functor. /// \tparam ValueType - type of value that can be obtained by dereferencing the iterator. /// By default it is the return type of \p UnaryFunction. -template< - class InputIterator, - class UnaryFunction, - class ValueType = - typename ::rocprim::detail::invoke_result< - UnaryFunction, typename std::iterator_traits::value_type - >::type -> +template::value_type>::type> class transform_iterator { public: diff --git a/rocprim/include/rocprim/thread/thread_operators.hpp b/rocprim/include/rocprim/thread/thread_operators.hpp index 46f3d72a4..59ae5ac7b 100644 --- a/rocprim/include/rocprim/thread/thread_operators.hpp +++ b/rocprim/include/rocprim/thread/thread_operators.hpp @@ -1,7 +1,7 @@ /****************************************************************************** * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. - * Modifications Copyright (c) 2017-2021, Advanced Micro Devices, Inc. All rights reserved. + * Modifications Copyright (c) 2017-2024, Advanced Micro Devices, Inc. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: @@ -175,7 +175,7 @@ namespace detail // rocPRIM (as well as Thrust) uses result type of BinaryFunction instead (if not void): // // using input_type = typename std::iterator_traits::value_type; -// using result_type = typename ::rocprim::detail::match_result_type< +// using result_type = typename ::rocprim::match_result_type< // input_type, BinaryFunction // >::type; // diff --git a/rocprim/include/rocprim/type_traits.hpp b/rocprim/include/rocprim/type_traits.hpp index c616decfe..8a40d79f3 100644 --- a/rocprim/include/rocprim/type_traits.hpp +++ b/rocprim/include/rocprim/type_traits.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -21,12 +21,11 @@ #ifndef ROCPRIM_TYPE_TRAITS_HPP_ #define ROCPRIM_TYPE_TRAITS_HPP_ -#include - -// Meta configuration for rocPRIM #include "config.hpp" #include "types.hpp" +#include + /// \addtogroup utilsmodule_typetraits /// @{ @@ -197,6 +196,117 @@ auto TwiddleOut(UnsignedBits key) }; #endif // DOXYGEN_SHOULD_SKIP_THIS +namespace detail +{ + +// invoke_result is based on std::invoke_result. +// The main difference is using ROCPRIM_HOST_DEVICE, this allows to +// use invoke_result with device-only lambdas/functors in host-only functions +// on HIP-clang. + +template +struct is_reference_wrapper : std::false_type +{}; +template +struct is_reference_wrapper> : std::true_type +{}; + +template +struct invoke_impl +{ + template + ROCPRIM_HOST_DEVICE static auto call(F&& f, Args&&... 
args) + -> decltype(std::forward(f)(std::forward(args)...)); +}; + +template +struct invoke_impl +{ + template::type, + class = typename std::enable_if::value>::type> + ROCPRIM_HOST_DEVICE static auto get(T&& t) -> T&&; + + template::type, + class = typename std::enable_if::value>::type> + ROCPRIM_HOST_DEVICE static auto get(T&& t) -> decltype(t.get()); + + template::type, + class = typename std::enable_if::value>::type, + class = typename std::enable_if::value>::type> + ROCPRIM_HOST_DEVICE static auto get(T&& t) -> decltype(*std::forward(t)); + + template::value>::type> + ROCPRIM_HOST_DEVICE static auto call(MT1 B::*pmf, T&& t, Args&&... args) + -> decltype((invoke_impl::get(std::forward(t)).*pmf)(std::forward(args)...)); + + template + ROCPRIM_HOST_DEVICE static auto call(MT B::*pmd, T&& t) + -> decltype(invoke_impl::get(std::forward(t)).*pmd); +}; + +template::type> +ROCPRIM_HOST_DEVICE auto INVOKE(F&& f, Args&&... args) + -> decltype(invoke_impl::call(std::forward(f), std::forward(args)...)); + +// Conforming C++14 implementation (is also a valid C++11 implementation): +template +struct invoke_result_impl +{}; +template +struct invoke_result_impl(), std::declval()...))), + F, + Args...> +{ + using type = decltype(INVOKE(std::declval(), std::declval()...)); +}; + +} // end namespace detail + +/// \brief Behaves like ``std::invoke_result``, but allows the use of invoke_result +/// with device-only lambdas/functors in host-only functions on HIP-clang. +/// +/// \tparam F Type of the function. +/// \tparam Args Input type(s) to the function ``F``. +template +struct invoke_result : detail::invoke_result_impl +{ +#ifdef DOXYGEN_DOCUMENTATION_BUILD + /// \brief The return type of the Callable type F if invoked with the arguments Args. + /// \hideinitializer + using type = detail::invoke_result_impl::type; +#endif // DOXYGEN_DOCUMENTATION_BUILD +}; + +/// \brief Helper type. It is an alias for ``invoke_result::type``. +/// +/// \tparam F Type of the function. 
+/// \tparam Args Input type(s) to the function ``F``. +template +using invoke_result_t = typename invoke_result::type; + +/// \brief Utility wrapper around ``invoke_result`` for binary operators. +/// +/// \tparam T Input type to the binary operator. +/// \tparam F Type of the binary operator. +template +struct invoke_result_binary_op +{ + /// \brief The result type of the binary operator. + using type = typename invoke_result::type; +}; + +/// \brief Helper type. It is an alias for ``invoke_result_binary_op::type``. +/// +/// \tparam T Input type to the binary operator. +/// \tparam F Type of the binary operator. +template +using invoke_result_binary_op_t = typename invoke_result_binary_op::type; END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp b/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp index 9ce2350ba..f53458818 100644 --- a/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2018-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -155,9 +155,19 @@ class warp_scan_dpp } template - ROCPRIM_DEVICE ROCPRIM_INLINE - void exclusive_scan(T input, T& output, T init, T& reduction, - BinaryFunction scan_op) + ROCPRIM_DEVICE ROCPRIM_INLINE void exclusive_scan( + T input, T& output, storage_type& /*storage*/, T& reduction, BinaryFunction scan_op) + { + inclusive_scan(input, output, scan_op); + // Broadcast value from the last thread in warp + reduction = warp_shuffle(output, WarpSize - 1, WarpSize); + // Convert inclusive scan result to exclusive + to_exclusive(output, output); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + exclusive_scan(T input, T& output, T init, T& reduction, BinaryFunction scan_op) { inclusive_scan(input, output, scan_op); // Broadcast value from the last thread in warp @@ -234,8 +244,8 @@ class warp_scan_dpp } protected: - ROCPRIM_DEVICE ROCPRIM_INLINE - void to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) + [[deprecated]] ROCPRIM_DEVICE ROCPRIM_INLINE void + to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) { (void) storage; return to_exclusive(inclusive_input, exclusive_output); diff --git a/rocprim/include/rocprim/warp/detail/warp_scan_shared_mem.hpp b/rocprim/include/rocprim/warp/detail/warp_scan_shared_mem.hpp index bedc99b5b..47736b065 100644 --- a/rocprim/include/rocprim/warp/detail/warp_scan_shared_mem.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_scan_shared_mem.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -104,6 +104,15 @@ class warp_scan_shared_mem to_exclusive(output, storage); } + template + ROCPRIM_DEVICE ROCPRIM_INLINE void exclusive_scan( + T input, T& output, storage_type& storage, T& reduction, BinaryFunction scan_op) + { + inclusive_scan(input, output, storage, scan_op); + reduction = storage.get().threads[WarpSize - 1]; + to_exclusive(output, storage); + } + template ROCPRIM_DEVICE ROCPRIM_INLINE void exclusive_scan(T input, T& output, T init, T& reduction, @@ -158,8 +167,8 @@ class warp_scan_shared_mem } protected: - ROCPRIM_DEVICE ROCPRIM_INLINE - void to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) + [[deprecated]] ROCPRIM_DEVICE ROCPRIM_INLINE void + to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) { (void) inclusive_input; return to_exclusive(exclusive_output, storage); diff --git a/rocprim/include/rocprim/warp/detail/warp_scan_shuffle.hpp b/rocprim/include/rocprim/warp/detail/warp_scan_shuffle.hpp index 457297528..106888c4a 100644 --- a/rocprim/include/rocprim/warp/detail/warp_scan_shuffle.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_scan_shuffle.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -118,6 +118,17 @@ class warp_scan_shuffle to_exclusive(output, output); } + template + ROCPRIM_DEVICE ROCPRIM_INLINE void exclusive_scan( + T input, T& output, storage_type& /*storage*/, T& reduction, BinaryFunction scan_op) + { + inclusive_scan(input, output, scan_op); + // Broadcast value from the last thread in warp + reduction = warp_shuffle(output, WarpSize - 1, WarpSize); + // Convert inclusive scan result to exclusive + to_exclusive(output, output); + } + template ROCPRIM_DEVICE ROCPRIM_INLINE void exclusive_scan(T input, T& output, T init, T& reduction, @@ -198,8 +209,8 @@ class warp_scan_shuffle } protected: - ROCPRIM_DEVICE ROCPRIM_INLINE - void to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) + [[deprecated]] ROCPRIM_DEVICE ROCPRIM_INLINE void + to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) { (void) storage; return to_exclusive(inclusive_input, exclusive_output); diff --git a/rocprim/include/rocprim/warp/warp_exchange.hpp b/rocprim/include/rocprim/warp/warp_exchange.hpp index 99581091b..6e54f9d3b 100644 --- a/rocprim/include/rocprim/warp/warp_exchange.hpp +++ b/rocprim/include/rocprim/warp/warp_exchange.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -210,8 +210,8 @@ class warp_exchange { const auto value = ::rocprim::warp_shuffle( input[src_idx], - flat_id / ItemsPerThread + dst_idx * (WarpSize / ItemsPerThread) - ); + flat_id / ItemsPerThread + dst_idx * (WarpSize / ItemsPerThread), + WarpSize); if(src_idx == flat_id % ItemsPerThread) { work_array[dst_idx] = value; @@ -328,10 +328,10 @@ class warp_exchange ROCPRIM_UNROLL for(unsigned int src_idx = 0; src_idx < ItemsPerThread; src_idx++) { - const auto value = ::rocprim::warp_shuffle( - input[src_idx], - (ItemsPerThread * flat_id + dst_idx) % WarpSize - ); + const auto value + = ::rocprim::warp_shuffle(input[src_idx], + (ItemsPerThread * flat_id + dst_idx) % WarpSize, + WarpSize); if(flat_id / (WarpSize / ItemsPerThread) == src_idx) { work_array[dst_idx] = value; diff --git a/rocprim/include/rocprim/warp/warp_scan.hpp b/rocprim/include/rocprim/warp/warp_scan.hpp index 116d10360..9e8738046 100644 --- a/rocprim/include/rocprim/warp/warp_scan.hpp +++ b/rocprim/include/rocprim/warp/warp_scan.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -443,6 +443,80 @@ class warp_scan return; } + /// \brief Performs exclusive scan without an initial value across threads in a logical warp + /// \tparam BinaryFunction binary function used for scan + /// \param input Thread input value + /// \param[out] output Reference to thread output value. Each threads value for the scan will + /// be written to it. May be aliased with `input`. The value written is unspecified for the first + /// thread of each logical warp. 
+ /// \param [in] storage Reference to a temporary storage object of type storage_type. + /// \param scan_op The function object used to combine elements used for the scan + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE auto exclusive_scan(T input, + T& output, + storage_type& storage, + BinaryFunction scan_op = BinaryFunction()) +#ifndef DOXYGEN_DOCUMENTATION_BUILD + -> std::enable_if_t +#else + -> void +#endif + { + base_type::exclusive_scan(input, output, storage, scan_op); + } + + /// \cond + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE auto exclusive_scan(T /*input*/, + T& /*output*/, + storage_type& /*storage*/, + BinaryFunction /*scan_op*/ = BinaryFunction()) + -> std::enable_if_t<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE)> + { + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size." + " Aborting warp scan."); + } + /// \endcond + + /// \brief Performs exclusive scan and reduction without an initial value across threads in + /// a logical warp + /// \tparam BinaryFunction binary function used for scan + /// \param input Thread input value + /// \param[out] output Reference to thread output value. Each threads value for the scan will + /// be written to it. May be aliased with `input`. The value written is unspecified for the first + /// thread of each logical warp. + /// \param[out] reduction Result of reducing of all `input` values in the logical warp. + /// \param [in] storage Reference to a temporary storage object of type storage_type. 
+ /// \param scan_op The function object used to combine elements used for the scan + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE auto exclusive_scan(T input, + T& output, + storage_type& storage, + T& reduction, + BinaryFunction scan_op = BinaryFunction()) +#ifndef DOXYGEN_DOCUMENTATION_BUILD + -> std::enable_if_t +#else + -> void +#endif + { + base_type::exclusive_scan(input, output, storage, reduction, scan_op); + } + + /// \cond + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE auto exclusive_scan(T /*input*/, + T& /*output*/, + storage_type& /*storage*/, + T& /*reduction*/, + BinaryFunction /*scan_op*/ = BinaryFunction()) + -> std::enable_if_t<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE)> + { + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size." + " Aborting warp scan."); + } + /// \endcond + /// \brief Performs inclusive and exclusive scan operations across threads /// in a logical warp. 
/// @@ -658,19 +732,18 @@ class warp_scan #ifndef DOXYGEN_SHOULD_SKIP_THIS protected: - + // These undocumented functions are used by hipCUB prior to version 3.1 template - ROCPRIM_DEVICE ROCPRIM_INLINE - auto to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) - -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + [[deprecated]] ROCPRIM_DEVICE ROCPRIM_INLINE auto + to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) -> + typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type { return base_type::to_exclusive(inclusive_input, exclusive_output, storage); } template - ROCPRIM_DEVICE ROCPRIM_INLINE - auto to_exclusive(T , T& , storage_type&) - -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + [[deprecated]] ROCPRIM_DEVICE ROCPRIM_INLINE auto to_exclusive(T, T&, storage_type&) -> + typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type { ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. 
Aborting warp sort."); return; diff --git a/scripts/autotune/create_optimization.py b/scripts/autotune/create_optimization.py index 677e02f3e..57c2e0d66 100755 --- a/scripts/autotune/create_optimization.py +++ b/scripts/autotune/create_optimization.py @@ -475,6 +475,16 @@ class AlgorithmDeviceAdjacentDifferenceInplace(Algorithm): def __init__(self, fallback_entries): Algorithm.__init__(self, fallback_entries) +class AlgorithmDeviceSegmentedRadixSort(Algorithm): + algorithm_name = 'device_segmented_radix_sort' + cpp_configuration_template_name = 'segmented_radix_sort_config_template' + config_selection_params = [ + SelectionType(name='key_type', is_optional=False), + SelectionType(name='value_type', is_optional=True)] + + def __init__(self, fallback_entries): + Algorithm.__init__(self, fallback_entries) + def filt_algo_regex(e, algorithm_name): if 'algo_regex' in e: return re.match(e['algo_regex'], algorithm_name) is not None @@ -508,6 +518,8 @@ def create_algorithm(algorithm_name: str, fallback_entries): return AlgorithmDeviceAdjacentDifference(fallback_entries) elif algorithm_name == 'device_adjacent_difference_inplace': return AlgorithmDeviceAdjacentDifferenceInplace(fallback_entries) + elif algorithm_name == 'device_segmented_radix_sort': + return AlgorithmDeviceSegmentedRadixSort(fallback_entries) else: raise(NotSupportedError(f'Algorithm "{algorithm_name}" is not supported (yet)')) diff --git a/scripts/autotune/templates/adjacent_difference_config_template b/scripts/autotune/templates/adjacent_difference_config_template index f40ad24cd..353e86d1f 100644 --- a/scripts/autotune/templates/adjacent_difference_config_template +++ b/scripts/autotune/templates/adjacent_difference_config_template @@ -10,7 +10,7 @@ adjacent_difference_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg'] {% macro general_case() -%} template -struct default_adjacent_difference_config : default_adjacent_difference_config_base +struct default_adjacent_difference_config : 
default_adjacent_difference_config_base::type {}; {%- endmacro %} diff --git a/scripts/autotune/templates/adjacent_difference_inplace_config_template b/scripts/autotune/templates/adjacent_difference_inplace_config_template index 1031bf5e0..3e5d7fa5c 100644 --- a/scripts/autotune/templates/adjacent_difference_inplace_config_template +++ b/scripts/autotune/templates/adjacent_difference_inplace_config_template @@ -10,7 +10,7 @@ adjacent_difference_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg'] {% macro general_case() -%} template -struct default_adjacent_difference_inplace_config : default_adjacent_difference_config_base +struct default_adjacent_difference_inplace_config : default_adjacent_difference_config_base::type {}; {%- endmacro %} diff --git a/scripts/autotune/templates/histogram_config_template b/scripts/autotune/templates/histogram_config_template index 60aff0382..d492ab5c8 100644 --- a/scripts/autotune/templates/histogram_config_template +++ b/scripts/autotune/templates/histogram_config_template @@ -11,7 +11,7 @@ histogram_config struct default_histogram_config : -default_histogram_config_base { }; +default_histogram_config_base::type { }; {%- endmacro %} {% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} diff --git a/scripts/autotune/templates/reduce_config_template b/scripts/autotune/templates/reduce_config_template index 5630d58dc..3d3924cf3 100644 --- a/scripts/autotune/templates/reduce_config_template +++ b/scripts/autotune/templates/reduce_config_template @@ -10,7 +10,7 @@ reduce_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, : {% macro general_case() -%} template struct default_reduce_config : -default_reduce_config_base { }; +default_reduce_config_base::type { }; {%- endmacro %} {% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} diff --git a/scripts/autotune/templates/scan_config_template 
b/scripts/autotune/templates/scan_config_template index 02f4fafa5..6de5a83ad 100644 --- a/scripts/autotune/templates/scan_config_template +++ b/scripts/autotune/templates/scan_config_template @@ -10,7 +10,7 @@ scan_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, ::r {% macro general_case() -%} template struct default_scan_config : -default_scan_config_base { }; +default_scan_config_base::type { }; {%- endmacro %} {% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} diff --git a/scripts/autotune/templates/scanbykey_config_template b/scripts/autotune/templates/scanbykey_config_template index e17a89de7..72653bb61 100644 --- a/scripts/autotune/templates/scanbykey_config_template +++ b/scripts/autotune/templates/scanbykey_config_template @@ -10,7 +10,7 @@ scan_by_key_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] {% macro general_case() -%} template struct default_scan_by_key_config : -default_scan_by_key_config_base { }; +default_scan_by_key_config_base::type { }; {%- endmacro %} {% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} diff --git a/scripts/autotune/templates/segmented_radix_sort_config_template b/scripts/autotune/templates/segmented_radix_sort_config_template new file mode 100644 index 000000000..0187df79e --- /dev/null +++ b/scripts/autotune/templates/segmented_radix_sort_config_template @@ -0,0 +1,20 @@ +{% extends "config_template" %} + +{% macro get_header_guard() %} +ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEGMENTED_RADIX_SORT_HPP_ +{%- endmacro %} + +{% macro kernel_configuration(measurement) -%} +segmented_radix_sort_config<{{ measurement['cfg']['lrb'] }}, {{ measurement['cfg']['srb'] }}, {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, {{ measurement['cfg']['eupws'] }}, typename std::conditional<{{ measurement['cfg']['wsc']['pa'] }},WarpSortConfig<{{ measurement['cfg']['wsc']['lwss'] 
}},{{ measurement['cfg']['wsc']['ipts'] }},{{ measurement['cfg']['wsc']['bss'] }},{{ measurement['cfg']['wsc']['pt'] }},{{ measurement['cfg']['wsc']['lwsm'] }},{{ measurement['cfg']['wsc']['iptm'] }},{{ measurement['cfg']['wsc']['bsm'] }}>,DisabledWarpSortConfig>::type> { }; +{%- endmacro %} + +{% macro general_case() -%} +template +struct default_segmented_radix_sort_config : default_segmented_radix_sort_config_base<6, 4>::type +{}; +{%- endmacro %} + +{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} +// Based on {{ based_on_type }} +template struct default_segmented_radix_sort_config({{ benchmark_of_architecture.name }}), key_type, value_type, {{ fallback_selection_criteria }}> : +{%- endmacro %} diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b7cc9b83d..aed06265f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -54,7 +54,9 @@ function(add_hip_test TEST_NAME TEST_SOURCES) get_filename_component(TEST_TARGET ${TEST_MAIN_SOURCE} NAME_WE) add_executable(${TEST_TARGET} ${TEST_SOURCES}) - rocm_install(TARGETS ${TEST_TARGET} COMPONENT tests) + if (ROCPRIM_INSTALL) + rocm_install(TARGETS ${TEST_TARGET} COMPONENT tests) + endif() target_include_directories(${TEST_TARGET} SYSTEM BEFORE PUBLIC @@ -130,9 +132,11 @@ add_subdirectory(rocprim) add_hip_test("hipgraph.basic" hipgraph/test_hipgraph_basic.cpp) add_hip_test("hipgraph.algs" hipgraph/test_hipgraph_algs.cpp) -rocm_install( - FILES "${INSTALL_TEST_FILE}" - DESTINATION "${CMAKE_INSTALL_BINDIR}/${PROJECT_NAME}" - COMPONENT tests - RENAME "CTestTestfile.cmake" -) +if (ROCPRIM_INSTALL) + rocm_install( + FILES "${INSTALL_TEST_FILE}" + DESTINATION "${CMAKE_INSTALL_BINDIR}/${PROJECT_NAME}" + COMPONENT tests + RENAME "CTestTestfile.cmake" + ) +endif() diff --git a/test/common_test_header.hpp b/test/common_test_header.hpp index 8ac8fdf50..62f6afcfd 100755 --- a/test/common_test_header.hpp +++ b/test/common_test_header.hpp @@ -60,6 +60,12 @@ } #endif +#if(defined(__GNUC__) || defined(__clang__)) && (defined(__GLIBCXX__) || defined(_LIBCPP_VERSION)) + #define ROCPRIM_HAS_INT128_SUPPORT 1 +#else + #define ROCPRIM_HAS_INT128_SUPPORT 0 +#endif + #define INSTANTIATE_TYPED_TEST_EXPANDED_1(line, test_suite_name, ...) \ namespace Id##line \ { \ @@ -110,19 +116,27 @@ inline int obtain_device_from_ctest() #endif } +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning( \ + disable : 4996) // This function or variable may be unsafe. Consider using _dupenv_s instead. 
+#endif inline bool use_hmm() { - if (getenv("ROCPRIM_USE_HMM") == nullptr) + if(getenv("ROCPRIM_USE_HMM") == nullptr) { return false; } - if (strcmp(getenv("ROCPRIM_USE_HMM"), "1") == 0) + if(strcmp(getenv("ROCPRIM_USE_HMM"), "1") == 0) { return true; } return false; } +#ifdef _MSC_VER + #pragma warning(pop) +#endif // Helper for HMM allocations: HMM is requested through ROCPRIM_USE_HMM=1 environment variable template diff --git a/test/hipgraph/test_hipgraph_algs.cpp b/test/hipgraph/test_hipgraph_algs.cpp index a0811fd89..816cff1f4 100644 --- a/test/hipgraph/test_hipgraph_algs.cpp +++ b/test/hipgraph/test_hipgraph_algs.cpp @@ -52,7 +52,9 @@ void generate_needles(const std::vector& input, std::vector& o output[indices.size() + i] = out_of_bounds_vals[i]; // Mix up the in-bounds and out-of-bounds values to make the test a bit more robust - std::random_shuffle(output.begin(), output.end()); + std::random_device rd; + std::default_random_engine gen(rd()); + std::shuffle(output.begin(), output.end(), gen); } template diff --git a/test/rocprim/CMakeLists.txt b/test/rocprim/CMakeLists.txt index 3fbc675ce..6ab37e809 100644 --- a/test/rocprim/CMakeLists.txt +++ b/test/rocprim/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -36,7 +36,9 @@ endfunction() function(add_rocprim_test_internal TEST_NAME TEST_SOURCES TEST_TARGET) add_executable(${TEST_TARGET} ${TEST_SOURCES}) - rocm_install(TARGETS ${TEST_TARGET} COMPONENT tests) + if (ROCPRIM_INSTALL) + rocm_install(TARGETS ${TEST_TARGET} COMPONENT tests) + endif() target_include_directories(${TEST_TARGET} SYSTEM BEFORE PUBLIC @@ -261,13 +263,14 @@ add_rocprim_test("rocprim.device_transform" test_device_transform.cpp) add_rocprim_test("rocprim.discard_iterator" test_discard_iterator.cpp) add_rocprim_test("rocprim.reverse_iterator" test_reverse_iterator.cpp) if(NOT USE_HIP_CPU) - add_rocprim_test("rocprim.texture_cache_iterator" test_texture_cache_iterator.cpp) +add_rocprim_test("rocprim.texture_cache_iterator" test_texture_cache_iterator.cpp) endif() add_rocprim_test("rocprim.thread" test_thread.cpp) add_rocprim_test("rocprim.thread_algos" test_thread_algos.cpp) add_rocprim_test("rocprim.transform_iterator" test_transform_iterator.cpp) add_rocprim_test("rocprim.no_half_operators" test_no_half_operators.cpp) add_rocprim_test("rocprim.intrinsics" test_intrinsics.cpp) +add_rocprim_test("rocprim.invoke_result" test_invoke_result.cpp) add_rocprim_test("rocprim.warp_exchange" test_warp_exchange.cpp) add_rocprim_test("rocprim.warp_load" test_warp_load.cpp) add_rocprim_test("rocprim.warp_reduce" test_warp_reduce.cpp) diff --git a/test/rocprim/indirect_iterator.hpp b/test/rocprim/indirect_iterator.hpp new file mode 100644 index 000000000..68328f7bb --- /dev/null +++ b/test/rocprim/indirect_iterator.hpp @@ -0,0 +1,179 @@ +// Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef TEST_INDIRECT_ITERATOR_HPP_ +#define TEST_INDIRECT_ITERATOR_HPP_ + +#include +#include + +#include + +namespace test_utils +{ + +// assign-through reference_wrapper implementation +template +class reference_wrapper +{ +public: + // types + using type = T; + + // construct/copy/destroy + explicit constexpr reference_wrapper(T& t) : _ptr(&t) {} + + constexpr reference_wrapper(const reference_wrapper&) noexcept = default; + + // assignment + constexpr reference_wrapper& operator=(const T& x) noexcept + { + *_ptr = x; + return *this; + } + + // access + constexpr operator T&() const noexcept + { + return *_ptr; + } + constexpr T& get() const noexcept + { + return *_ptr; + } + +private: + T* _ptr; +}; + +// Iterator used in tests to check situtations when value_type of the +// iterator is not the same as the return type of operator[]. 
+// It is a simplified version of device_vector::iterator from thrust. +template +class indirect_iterator +{ +public: + // Iterator traits + using difference_type = std::ptrdiff_t; + using value_type = T; + using pointer = T*; + using reference = reference_wrapper; + + using iterator_category = std::random_access_iterator_tag; + + ROCPRIM_HOST_DEVICE inline indirect_iterator(T* ptr) : ptr_(ptr) {} + + ROCPRIM_HOST_DEVICE inline ~indirect_iterator() = default; + + ROCPRIM_HOST_DEVICE inline indirect_iterator& operator++() + { + ++ptr_; + return *this; + } + + ROCPRIM_HOST_DEVICE inline indirect_iterator operator++(int) + { + indirect_iterator old = *this; + ++ptr_; + return old; + } + + ROCPRIM_HOST_DEVICE inline indirect_iterator& operator--() + { + --ptr_; + return *this; + } + + ROCPRIM_HOST_DEVICE inline indirect_iterator operator--(int) + { + indirect_iterator old = *this; + --ptr_; + return old; + } + + ROCPRIM_HOST_DEVICE inline reference operator*() const + { + return *ptr_; + } + + ROCPRIM_HOST_DEVICE inline reference operator[](difference_type n) const + { + return reference{*(ptr_ + n)}; + } + + ROCPRIM_HOST_DEVICE inline indirect_iterator operator+(difference_type distance) const + { + auto i = ptr_ + distance; + return indirect_iterator{i}; + } + + ROCPRIM_HOST_DEVICE inline indirect_iterator& operator+=(difference_type distance) + { + ptr_ += distance; + return *this; + } + + ROCPRIM_HOST_DEVICE inline indirect_iterator operator-(difference_type distance) const + { + auto i = ptr_ - distance; + return indirect_iterator{i}; + } + + ROCPRIM_HOST_DEVICE inline indirect_iterator& operator-=(difference_type distance) + { + ptr_ -= distance; + return *this; + } + + ROCPRIM_HOST_DEVICE inline difference_type operator-(indirect_iterator other) const + { + return ptr_ - other.ptr_; + } + + ROCPRIM_HOST_DEVICE inline bool operator==(indirect_iterator other) const + { + return ptr_ == other.ptr_; + } + + ROCPRIM_HOST_DEVICE inline bool 
operator!=(indirect_iterator other) const + { + return ptr_ != other.ptr_; + } + +private: + T* ptr_; +}; + +template +inline auto wrap_in_indirect_iterator(T* ptr) -> + typename std::enable_if>::type +{ + return indirect_iterator(ptr); +} + +template +inline auto wrap_in_indirect_iterator(T* ptr) -> typename std::enable_if::type +{ + return ptr; +} + +} // namespace test_utils + +#endif // TEST_INDIRECT_ITERATOR_HPP_ diff --git a/test/rocprim/test_device_adjacent_difference.cpp b/test/rocprim/test_device_adjacent_difference.cpp index 197a170f7..fa97eecb6 100644 --- a/test/rocprim/test_device_adjacent_difference.cpp +++ b/test/rocprim/test_device_adjacent_difference.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -22,6 +22,7 @@ #include "../common_test_header.hpp" +#include "indirect_iterator.hpp" #include "test_utils_types.hpp" #include @@ -37,70 +38,131 @@ namespace { -template -auto dispatch_adjacent_difference(std::true_type /*left*/, - std::false_type /*in_place*/, - void* const temporary_storage, - std::size_t& storage_size, - const InputIt input, - const OutputIt output, - Args&&... args) +enum class api_variant +{ + no_alias, + alias, + in_place +}; + +std::string to_string(api_variant aliasing) +{ + switch(aliasing) + { + case api_variant::no_alias: return "no_alias"; + case api_variant::alias: return "alias"; + case api_variant::in_place: return "in_place"; + } +} + +template +auto dispatch_adjacent_difference( + std::true_type /*left*/, + std::integral_constant /*aliasing*/, + void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt output, + Args&&... 
args) { return ::rocprim::adjacent_difference( temporary_storage, storage_size, input, output, std::forward(args)...); } -template -auto dispatch_adjacent_difference(std::false_type /*left*/, - std::false_type /*in_place*/, - void* const temporary_storage, - std::size_t& storage_size, - const InputIt input, - const OutputIt output, - Args&&... args) +template +auto dispatch_adjacent_difference( + std::false_type /*left*/, + std::integral_constant /*aliasing*/, + void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt output, + Args&&... args) { return ::rocprim::adjacent_difference_right( temporary_storage, storage_size, input, output, std::forward(args)...); } -template -auto dispatch_adjacent_difference(std::true_type /*left*/, - std::true_type /*in_place*/, - void* const temporary_storage, - std::size_t& storage_size, - const InputIt input, - const OutputIt /*output*/, - Args&&... args) +template +auto dispatch_adjacent_difference( + std::true_type /*left*/, + std::integral_constant /*aliasing*/, + void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt /*output*/, + Args&&... args) { return ::rocprim::adjacent_difference_inplace( temporary_storage, storage_size, input, std::forward(args)...); } -template -auto dispatch_adjacent_difference(std::false_type /*left*/, - std::true_type /*in_place*/, - void* const temporary_storage, - std::size_t& storage_size, - const InputIt input, - const OutputIt /*output*/, - Args&&... args) +template +auto dispatch_adjacent_difference( + std::false_type /*left*/, + std::integral_constant /*aliasing*/, + void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt /*output*/, + Args&&... 
args) { return ::rocprim::adjacent_difference_right_inplace( temporary_storage, storage_size, input, std::forward(args)...); } +template +auto dispatch_adjacent_difference( + std::true_type /*left*/, + std::integral_constant /*aliasing*/, + void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt output, + Args&&... args) +{ + return ::rocprim::adjacent_difference_inplace(temporary_storage, + storage_size, + input, + output, + std::forward(args)...); +} + +template +auto dispatch_adjacent_difference( + std::false_type /*left*/, + std::integral_constant /*aliasing*/, + void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt output, + Args&&... args) +{ + return ::rocprim::adjacent_difference_right_inplace(temporary_storage, + storage_size, + input, + output, + std::forward(args)...); +} + template auto get_expected_result(const std::vector& input, const BinaryFunction op, @@ -125,35 +187,38 @@ auto get_expected_result(const std::vector& input, // Params for tests template + class OutputType = InputType, + bool Left = true, + api_variant Aliasing = api_variant::no_alias, + bool UseIdentityIterator = false, + class Config = rocprim::default_config, + bool UseGraphs = false, + bool UseIndirectIterator = false> struct DeviceAdjacentDifferenceParams { - using input_type = InputType; - using output_type = OutputType; - static constexpr bool left = Left; - static constexpr bool in_place = InPlace; - static constexpr bool use_identity_iterator = UseIdentityIterator; - using config = Config; - static constexpr bool use_graphs = UseGraphs; + using input_type = InputType; + using output_type = OutputType; + static constexpr bool left = Left; + static constexpr api_variant aliasing = Aliasing; + static constexpr bool use_identity_iterator = UseIdentityIterator; + using config = Config; + static constexpr bool use_graphs = UseGraphs; + static constexpr bool use_indirect_iterator = 
UseIndirectIterator; }; template class RocprimDeviceAdjacentDifferenceTests : public ::testing::Test { public: - using input_type = typename Params::input_type; - using output_type = typename Params::output_type; - static constexpr bool left = Params::left; - static constexpr bool in_place = Params::in_place; - static constexpr bool use_identity_iterator = Params::use_identity_iterator; - static constexpr bool debug_synchronous = false; - using config = typename Params::config; - static constexpr bool use_graphs = Params::use_graphs; + using input_type = typename Params::input_type; + using output_type = typename Params::output_type; + static constexpr bool left = Params::left; + static constexpr api_variant aliasing = Params::aliasing; + static constexpr bool use_identity_iterator = Params::use_identity_iterator; + static constexpr bool use_indirect_iterator = Params::use_indirect_iterator; + static constexpr bool debug_synchronous = false; + using config = typename Params::config; + static constexpr bool use_graphs = Params::use_graphs; }; using custom_double2 = test_utils::custom_test_type; @@ -173,19 +238,60 @@ using RocprimDeviceAdjacentDifferenceTestsParams = ::testing::Types< // Tests with default configuration DeviceAdjacentDifferenceParams, DeviceAdjacentDifferenceParams, - DeviceAdjacentDifferenceParams, - DeviceAdjacentDifferenceParams, - DeviceAdjacentDifferenceParams, - DeviceAdjacentDifferenceParams, - DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, + // this is changed to not use identity iterator + // because the function doesn't work with it, should be changed back, when fixed + DeviceAdjacentDifferenceParams, // Tests for supported config structs - DeviceAdjacentDifferenceParams, - DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, // Tests for 
different size_limits - DeviceAdjacentDifferenceParams>, - DeviceAdjacentDifferenceParams>, - DeviceAdjacentDifferenceParams>, - DeviceAdjacentDifferenceParams>; + DeviceAdjacentDifferenceParams>, + DeviceAdjacentDifferenceParams, + false, + true>, + DeviceAdjacentDifferenceParams, + false, + true>, + DeviceAdjacentDifferenceParams>; TYPED_TEST_SUITE(RocprimDeviceAdjacentDifferenceTests, RocprimDeviceAdjacentDifferenceTestsParams); @@ -195,15 +301,16 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceTests, AdjacentDifference) SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); - using T = typename TestFixture::input_type; - using output_type = typename TestFixture::output_type; - static constexpr bool left = TestFixture::left; - static constexpr bool in_place = TestFixture::in_place; - static constexpr bool use_identity_iterator = TestFixture::use_identity_iterator; - static constexpr bool debug_synchronous = TestFixture::debug_synchronous; - using Config = typename TestFixture::config; + using T = typename TestFixture::input_type; + using output_type = typename TestFixture::output_type; + static constexpr bool left = TestFixture::left; + static constexpr api_variant aliasing = TestFixture::aliasing; + const bool debug_synchronous = TestFixture::debug_synchronous; + static constexpr bool use_identity_iterator = TestFixture::use_identity_iterator; + using Config = typename TestFixture::config; - SCOPED_TRACE(testing::Message() << "left = " << left << ", in_place = " << in_place); + SCOPED_TRACE(testing::Message() + << "left = " << left << ", api_variant = " << to_string(aliasing)); for(std::size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -223,123 +330,149 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceTests, AdjacentDifference) SCOPED_TRACE(testing::Message() << "with size = " << size); // Generate data - const std::vector input = test_utils::get_random_data(size, 1, 100, 
seed_value); - std::vector output(input.size()); + std::vector input = test_utils::get_random_data(size, 1, 100, seed_value); - T* d_input; - output_type* d_output = nullptr; + T* d_input; HIP_CHECK( test_common_utils::hipMallocHelper(&d_input, input.size() * sizeof(input[0]))); HIP_CHECK(hipMemcpy( d_input, input.data(), input.size() * sizeof(input[0]), hipMemcpyHostToDevice)); - if(!in_place) - { - HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, - output.size() * sizeof(output[0]))); - } + static constexpr auto left_tag = rocprim::detail::bool_constant{}; + static constexpr auto alias_tag = std::integral_constant{}; - static constexpr auto left_tag = rocprim::detail::bool_constant {}; - static constexpr auto in_place_tag = rocprim::detail::bool_constant {}; + auto input_it + = test_utils::wrap_in_indirect_iterator( + d_input); - // Calculate expected results on host - const auto expected - = get_expected_result(input, rocprim::minus<> {}, left_tag); - - const auto output_it - = test_utils::wrap_in_identity_iterator(d_output); - - hipGraph_t graph; + hipGraph_t graph; hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) graph = test_utils::createGraphHelper(stream); - + // Allocate temporary storage std::size_t temp_storage_size; void* d_temp_storage = nullptr; HIP_CHECK(dispatch_adjacent_difference(left_tag, - in_place_tag, + alias_tag, d_temp_storage, temp_storage_size, - d_input, - output_it, + input_it, + (output_type*){nullptr}, size, - rocprim::minus<> {}, + rocprim::minus<>{}, stream, - TestFixture::debug_synchronous)); + debug_synchronous)); - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + ASSERT_GT(temp_storage_size, 0); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size)); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - - // Run - 
HIP_CHECK(dispatch_adjacent_difference(left_tag, - in_place_tag, - d_temp_storage, - temp_storage_size, - d_input, - output_it, - size, - rocprim::minus<> {}, - stream, - TestFixture::debug_synchronous)); - HIP_CHECK(hipGetLastError()); + // We might call the API multiple times, with almost the same parameter + // (in-place and out-of-place) + // we should be able to use the same amount of temp storage for and get the same + // results (maybe with different types) for both. + auto run_and_verify = [&](const auto output_it, auto* d_output) + { + // Run + HIP_CHECK(dispatch_adjacent_difference(left_tag, + alias_tag, + d_temp_storage, + temp_storage_size, + input_it, + output_it, + size, + rocprim::minus<>{}, + stream, + debug_synchronous)); + HIP_CHECK(hipGetLastError()); + + if(TestFixture::use_graphs) + { + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + } + + // input_type for in-place, output_type for out of place + using current_output_type = std::remove_reference_t; + + // allocate memory for output + std::vector output(size); + + // Copy output to host + HIP_CHECK(hipMemcpy(output.data(), + d_output, + output.size() * sizeof(output[0]), + hipMemcpyDeviceToHost)); + + // Calculate expected results on host + const auto expected + = get_expected_result(input, rocprim::minus<>{}, left_tag); + + // Check if output values are as expected + test_utils::assert_near( + output, + expected, + std::max(test_utils::precision, test_utils::precision)); + }; + + // if api_variant is not in_place we should check the non aliased function call + if(aliasing != api_variant::in_place) + { + output_type* d_output = nullptr; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, size * sizeof(*d_output))); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + if(TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); - // Copy output to host - HIP_CHECK( - 
hipMemcpy(output.data(), - in_place ? static_cast(d_input) : static_cast(d_output), - output.size() * sizeof(output[0]), - hipMemcpyDeviceToHost)); + const auto output_it + = test_utils::wrap_in_identity_iterator(d_output); - // Check if output values are as expected - ASSERT_NO_FATAL_FAILURE(test_utils::assert_near( - output, - expected, - std::max(test_utils::precision, test_utils::precision))); + ASSERT_NO_FATAL_FAILURE(run_and_verify(output_it, d_output)); - hipFree(d_input); - if(!in_place) - { hipFree(d_output); } - hipFree(d_temp_storage); - if (TestFixture::use_graphs) + // if api_variant is not no_alias we should check the inplace function call + if(aliasing != api_variant::no_alias) + { + if(TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + + ASSERT_NO_FATAL_FAILURE(run_and_verify(input_it, d_input)); + } + + if(TestFixture::use_graphs) { test_utils::cleanupGraphHelper(graph, graph_instance); HIP_CHECK(hipStreamDestroy(stream)); } + + hipFree(d_temp_storage); + hipFree(d_input); } } } // Params for tests -template +template struct DeviceAdjacentDifferenceLargeParams { - static constexpr bool left = Left; - static constexpr bool in_place = InPlace; - static constexpr bool use_graphs = UseGraphs; + static constexpr bool left = Left; + static constexpr api_variant aliasing = Aliasing; + static constexpr bool use_graphs = UseGraphs; }; template class RocprimDeviceAdjacentDifferenceLargeTests : public ::testing::Test { public: - static constexpr bool left = Params::left; - static constexpr bool in_place = Params::in_place; - static constexpr bool debug_synchronous = false; - static constexpr bool use_graphs = Params::use_graphs; + static constexpr bool left = Params::left; + static constexpr api_variant aliasing = Params::aliasing; + static constexpr bool debug_synchronous = false; + static constexpr bool use_graphs = Params::use_graphs; }; template @@ -451,9 +584,10 @@ class check_output_iterator }; using 
RocprimDeviceAdjacentDifferenceLargeTestsParams - = ::testing::Types, - DeviceAdjacentDifferenceLargeParams, - DeviceAdjacentDifferenceLargeParams>; + = ::testing::Types, + DeviceAdjacentDifferenceLargeParams, + DeviceAdjacentDifferenceLargeParams, + DeviceAdjacentDifferenceLargeParams>; TYPED_TEST_SUITE(RocprimDeviceAdjacentDifferenceLargeTests, RocprimDeviceAdjacentDifferenceLargeTestsParams); @@ -467,13 +601,14 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) using T = size_t; static constexpr bool is_left = TestFixture::left; - static constexpr bool is_in_place = TestFixture::in_place; + static constexpr api_variant aliasing = TestFixture::aliasing; + const bool debug_synchronous = TestFixture::debug_synchronous; static constexpr unsigned int sampling_rate = 10000; using OutputIterator = check_output_iterator; using flag_type = OutputIterator::flag_type; SCOPED_TRACE(testing::Message() - << "is_left = " << is_left << ", is_in_place = " << is_in_place); + << "is_left = " << is_left << ", api_variant = " << to_string(aliasing)); hipStream_t stream = 0; // default if (TestFixture::use_graphs) @@ -514,7 +649,7 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) { return (smaller_value + larger_value) / 2 + (is_left ? 
1 : 0); }; static constexpr auto left_tag = rocprim::detail::bool_constant{}; - static constexpr auto in_place_tag = rocprim::detail::bool_constant{}; + static constexpr auto aliasing_tag = std::integral_constant{}; hipGraph_t graph; hipGraphExec_t graph_instance; @@ -525,7 +660,7 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) std::size_t temp_storage_size; void* d_temp_storage = nullptr; HIP_CHECK(dispatch_adjacent_difference(left_tag, - in_place_tag, + aliasing_tag, d_temp_storage, temp_storage_size, input, @@ -533,7 +668,7 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) size, op, stream, - TestFixture::debug_synchronous)); + debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); @@ -547,7 +682,7 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) // Run HIP_CHECK(dispatch_adjacent_difference(left_tag, - in_place_tag, + aliasing_tag, d_temp_storage, temp_storage_size, input, @@ -555,7 +690,7 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) size, op, stream, - TestFixture::debug_synchronous)); + debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); diff --git a/test/rocprim/test_device_radix_sort.cpp.in b/test/rocprim/test_device_radix_sort.cpp.in index 4d0931246..65d34f6a6 100644 --- a/test/rocprim/test_device_radix_sort.cpp.in +++ b/test/rocprim/test_device_radix_sort.cpp.in @@ -50,7 +50,7 @@ #endif #if ROCPRIM_TEST_TYPE_SLICE == 0 -#if defined(__GNUC__) || defined(__clang__) +#if ROCPRIM_HAS_INT128_SUPPORT INSTANTIATE(params<__int128_t, __int128_t>) INSTANTIATE(params<__uint128_t, __uint128_t>) #endif diff --git a/test/rocprim/test_device_scan.cpp b/test/rocprim/test_device_scan.cpp index 8511ec6d6..9f2bdac22 100644 --- a/test/rocprim/test_device_scan.cpp +++ b/test/rocprim/test_device_scan.cpp @@ -1,6 +1,6 @@ // MIT License // 
-// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -150,12 +150,12 @@ TYPED_TEST_SUITE(RocprimDeviceScanTests, RocprimDeviceScanTestsParams); TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) { - using T = typename TestFixture::input_type; - using U = typename TestFixture::output_type; + using T = typename TestFixture::input_type; + using U = typename TestFixture::output_type; using scan_op_type = typename TestFixture::scan_op_type; // if scan_op_type is rocprim::plus and input_type is bfloat16 or half, // use float as device-side accumulator and double as host-side accumulator - using acc_type = typename accum_type::type; + using acc_type = typename accum_type::type; const bool debug_synchronous = TestFixture::debug_synchronous; int device_id = test_common_utils::obtain_device_from_ctest(); @@ -182,9 +182,9 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) // scan function scan_op_type scan_op; - auto input_iterator = rocprim::make_transform_iterator( - rocprim::make_constant_iterator(T(345)), - [] (T in) { return static_cast(in); }); + auto input_iterator + = rocprim::make_transform_iterator(rocprim::make_constant_iterator(T(345)), + [](T in) { return static_cast(in); }); hipGraph_t graph; hipGraphExec_t graph_instance; @@ -317,8 +317,9 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) expected.begin(), scan_op ); - auto input_iterator = rocprim::make_transform_iterator( - d_input, [] (T in) { return static_cast(in); }); + auto input_iterator + = rocprim::make_transform_iterator(d_input, + [](T in) { return static_cast(in); }); hipGraph_t graph; hipGraphExec_t graph_instance; @@ -329,13 +330,15 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) size_t 
temp_storage_size_bytes; void * d_temp_storage = nullptr; // Get size of d_temp_storage - HIP_CHECK( - rocprim::inclusive_scan( - d_temp_storage, temp_storage_size_bytes, input_iterator, - test_utils::wrap_in_identity_iterator(d_output), - input.size(), scan_op, stream, TestFixture::debug_synchronous - ) - ); + HIP_CHECK(rocprim::inclusive_scan( + d_temp_storage, + temp_storage_size_bytes, + input_iterator, + test_utils::wrap_in_identity_iterator(d_output), + input.size(), + scan_op, + stream, + TestFixture::debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); @@ -351,13 +354,15 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) test_utils::resetGraphHelper(graph, graph_instance, stream); // Run - HIP_CHECK( - rocprim::inclusive_scan( - d_temp_storage, temp_storage_size_bytes, input_iterator, - test_utils::wrap_in_identity_iterator(d_output), - input.size(), scan_op, stream, TestFixture::debug_synchronous - ) - ); + HIP_CHECK(rocprim::inclusive_scan( + d_temp_storage, + temp_storage_size_bytes, + input_iterator, + test_utils::wrap_in_identity_iterator(d_output), + input.size(), + scan_op, + stream, + TestFixture::debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); @@ -474,8 +479,9 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) scan_op ); - auto input_iterator = rocprim::make_transform_iterator( - d_input, [] (T in) { return static_cast(in); }); + auto input_iterator + = rocprim::make_transform_iterator(d_input, + [](T in) { return static_cast(in); }); hipGraph_t graph; hipGraphExec_t graph_instance; @@ -486,13 +492,16 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; // Get size of d_temp_storage - HIP_CHECK( - rocprim::exclusive_scan( - d_temp_storage, temp_storage_size_bytes, input_iterator, - 
test_utils::wrap_in_identity_iterator(d_output), - initial_value, input.size(), scan_op, stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::exclusive_scan( + d_temp_storage, + temp_storage_size_bytes, + input_iterator, + test_utils::wrap_in_identity_iterator(d_output), + initial_value, + input.size(), + scan_op, + stream, + debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); @@ -508,13 +517,16 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) test_utils::resetGraphHelper(graph, graph_instance, stream); // Run - HIP_CHECK( - rocprim::exclusive_scan( - d_temp_storage, temp_storage_size_bytes, input_iterator, - test_utils::wrap_in_identity_iterator(d_output), - initial_value, input.size(), scan_op, stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::exclusive_scan( + d_temp_storage, + temp_storage_size_bytes, + input_iterator, + test_utils::wrap_in_identity_iterator(d_output), + initial_value, + input.size(), + scan_op, + stream, + debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); @@ -646,8 +658,9 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) scan_op, keys_compare_op ); - auto input_iterator = rocprim::make_transform_iterator( - d_input, [] (T in) { return static_cast(in); }); + auto input_iterator + = rocprim::make_transform_iterator(d_input, + [](T in) { return static_cast(in); }); hipGraph_t graph; hipGraphExec_t graph_instance; @@ -658,12 +671,16 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; // Get size of d_temp_storage - HIP_CHECK( - rocprim::inclusive_scan_by_key( - d_temp_storage, temp_storage_size_bytes, d_keys, input_iterator, - d_output, input.size(), scan_op, keys_compare_op, stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::inclusive_scan_by_key(d_temp_storage, + temp_storage_size_bytes, + 
d_keys, + input_iterator, + d_output, + input.size(), + scan_op, + keys_compare_op, + stream, + debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); @@ -679,12 +696,16 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) test_utils::resetGraphHelper(graph, graph_instance, stream); // Run - HIP_CHECK( - rocprim::inclusive_scan_by_key( - d_temp_storage, temp_storage_size_bytes, d_keys, input_iterator, - d_output, input.size(), scan_op, keys_compare_op, stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::inclusive_scan_by_key(d_temp_storage, + temp_storage_size_bytes, + d_keys, + input_iterator, + d_output, + input.size(), + scan_op, + keys_compare_op, + stream, + debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); @@ -820,8 +841,9 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) scan_op, keys_compare_op ); - auto input_iterator = rocprim::make_transform_iterator( - d_input, [] (T in) { return static_cast(in); }); + auto input_iterator + = rocprim::make_transform_iterator(d_input, + [](T in) { return static_cast(in); }); hipGraph_t graph; hipGraphExec_t graph_instance; @@ -832,12 +854,17 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; // Get size of d_temp_storage - HIP_CHECK( - rocprim::exclusive_scan_by_key( - d_temp_storage, temp_storage_size_bytes, d_keys, input_iterator, - d_output, initial_value, input.size(), scan_op, keys_compare_op, stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::exclusive_scan_by_key(d_temp_storage, + temp_storage_size_bytes, + d_keys, + input_iterator, + d_output, + initial_value, + input.size(), + scan_op, + keys_compare_op, + stream, + debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); @@ -853,12 +880,17 @@ 
TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) test_utils::resetGraphHelper(graph, graph_instance, stream); // Run - HIP_CHECK( - rocprim::exclusive_scan_by_key( - d_temp_storage, temp_storage_size_bytes, d_keys, input_iterator, - d_output, initial_value, input.size(), scan_op, keys_compare_op, stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::exclusive_scan_by_key(d_temp_storage, + temp_storage_size_bytes, + d_keys, + input_iterator, + d_output, + initial_value, + input.size(), + scan_op, + keys_compare_op, + stream, + debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); @@ -1630,8 +1662,9 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) = rocprim::future_value>{ future_iter}; - auto input_iterator = rocprim::make_transform_iterator( - d_input, [] (T in) { return static_cast(in); }); + auto input_iterator + = rocprim::make_transform_iterator(d_input, + [](T in) { return static_cast(in); }); hipGraph_t graph; hipGraphExec_t graph_instance; @@ -1643,7 +1676,9 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) char* d_temp_storage = nullptr; // Get size of d_temp_storage HIP_CHECK(rocprim::exclusive_scan( - nullptr, temp_storage_size_bytes, input_iterator, + nullptr, + temp_storage_size_bytes, + input_iterator, test_utils::wrap_in_identity_iterator(d_output), future_initial_value, input.size(), @@ -1686,7 +1721,9 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) // Run HIP_CHECK(rocprim::exclusive_scan( - d_temp_storage, temp_storage_size_bytes, input_iterator, + d_temp_storage, + temp_storage_size_bytes, + input_iterator, test_utils::wrap_in_identity_iterator(d_output), future_initial_value, input.size(), diff --git a/test/rocprim/test_device_segmented_radix_sort.hpp b/test/rocprim/test_device_segmented_radix_sort.hpp index d75f068ed..34e462a62 100644 --- a/test/rocprim/test_device_segmented_radix_sort.hpp +++ 
b/test/rocprim/test_device_segmented_radix_sort.hpp @@ -53,44 +53,47 @@ struct params using config = Config; }; -using config_default = rocprim::segmented_radix_sort_config< - 4, //< long radix bits - 3, //< short radix bits - rocprim::kernel_config<256, 4> //< sort block size, items per thread - >; - -using config_semi_custom = rocprim::segmented_radix_sort_config< - 3, //< long radix bits - 2, //< short radix bits - rocprim::kernel_config<128, 4>, //< sort block size, items per thread - rocprim::WarpSortConfig<16, //< logical warp size small - 8 //< items per thread small - >>; - -using config_semi_custom_warp_config = rocprim::segmented_radix_sort_config< - 3, //< long radix bits - 2, //< short radix bits - rocprim::kernel_config<128, 4>, //< sort block size, items per thread - rocprim::WarpSortConfig<16, //< logical warp size small - 2, //< items per thread small - 512, //< block size small - 0, //< partitioning threshold - true //< enable unpartitioned sort - >>; - -using config_custom = rocprim::segmented_radix_sort_config< - 3, //< long radix bits - 2, //< short radix bits - rocprim::kernel_config<128, 4>, //< sort block size, items per thread - rocprim::WarpSortConfig<16, //< logical warp size small - 2, //< items per thread small - 512, //< block size small - 0, //< partitioning threshold - true, //< enable unpartitioned sort - 32, //< logical warp size medium - 4, //< items per thread medium - 256 //< block size medium - >>; +using config_default = rocprim::segmented_radix_sort_config<4, //< long radix bits + 3, //< short radix bits + 256, //< sort block size, + 4 //< items per thread + >; + +using config_semi_custom + = rocprim::segmented_radix_sort_config<3, //< long radix bits + 2, //< short radix bits + 128, //< sort block size + 4, //< items per thread + false, //< enable unpartitioned sort + rocprim::WarpSortConfig<16, //< logical warp size small + 8 //< items per thread small + >>; + +using config_semi_custom_warp_config + = 
rocprim::segmented_radix_sort_config<3, //< long radix bits + 2, //< short radix bits + 128, //< sort block size + 4, //< items per thread + true, //< enable unpartitioned sort + rocprim::WarpSortConfig<16, //< logical warp size small + 2, //< items per thread small + 512, //< block size small + 0>>; //< partitioning threshold + +using config_custom + = rocprim::segmented_radix_sort_config<3, //< long radix bits + 2, //< short radix bits + 128, //< sort block size + 4, //< items per thread + true, //< enable unpartitioned sort + rocprim::WarpSortConfig<16, //< logical warp size small + 2, //< items per thread small + 512, //< block size small + 0, //< partitioning threshold + 32, //< logical warp size medium + 4, //< items per thread medium + 256 //< block size medium + >>; template class RocprimDeviceSegmentedRadixSort : public ::testing::Test diff --git a/test/rocprim/test_device_select.cpp b/test/rocprim/test_device_select.cpp index b903d5b4d..88d5562b4 100644 --- a/test/rocprim/test_device_select.cpp +++ b/test/rocprim/test_device_select.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -977,7 +977,7 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) if (TestFixture::use_graphs) test_utils::resetGraphHelper(graph, graph_instance, stream); - + // Run HIP_CHECK( rocprim::unique_by_key( @@ -997,7 +997,7 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) if (TestFixture::use_graphs) graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - + HIP_CHECK(hipDeviceSynchronize()); // Check if number of selected value is as expected @@ -1050,6 +1050,200 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) HIP_CHECK(hipStreamDestroy(stream)); } +TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKeyAlias) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + // This test checks correctness of in-place unique_by_key (so input keys and values iterators + // are passed as output iterators as well) + using key_type = typename TestFixture::key_type; + using value_type = typename TestFixture::value_type; + using output_key_type = key_type; + using output_value_type = value_type; + + using op_type = rocprim::equal_to; + + using scan_op_type = rocprim::plus; + static constexpr bool use_identity_iterator = TestFixture::use_identity_iterator; + const bool debug_synchronous = TestFixture::debug_synchronous; + + hipStream_t stream = 0; // default stream + if(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + const auto probabilities = get_discontinuity_probabilities(); + for(auto size : test_utils::get_sizes(seed_value)) + { + SCOPED_TRACE(testing::Message() << "with size = " << size); + for(auto p : probabilities) + { + SCOPED_TRACE(testing::Message() << "with p = " << p); + + // Generate data + std::vector input_keys(size); + { + std::vector input01 + = test_utils::get_random_data01(size, p, seed_value); + std::partial_sum(input01.begin(), + input01.end(), + input_keys.begin(), + scan_op_type()); + } + const auto input_values + = test_utils::get_random_data(size, -1000, 1000, seed_value); + + // Allocate and copy to device + key_type* d_keys_input; + value_type* d_values_input; + unsigned int* d_selected_count_output; + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_keys_input, + input_keys.size() * sizeof(input_keys[0]))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_values_input, + input_values.size() + * sizeof(input_values[0]))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_selected_count_output, + sizeof(unsigned int))); + HIP_CHECK(hipMemcpy(d_keys_input, + input_keys.data(), + input_keys.size() * sizeof(input_keys[0]), + hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_values_input, + input_values.data(), + input_values.size() * sizeof(input_values[0]), + hipMemcpyHostToDevice)); + HIP_CHECK(hipDeviceSynchronize()); + + // Calculate expected results on host + std::vector expected_keys; + std::vector expected_values; + expected_keys.reserve(input_keys.size()); + expected_values.reserve(input_values.size()); + if(size > 0) + { + expected_keys.push_back(input_keys[0]); + expected_values.push_back(input_values[0]); + for(size_t i = 1; i < input_keys.size(); i++) + { + if(!op_type()(input_keys[i - 1], input_keys[i])) + { + expected_keys.push_back(input_keys[i]); + expected_values.push_back(input_values[i]); + } + } + } + + 
hipGraph_t graph; + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + + // temp storage + size_t temp_storage_size_bytes; + // Get size of d_temp_storage + HIP_CHECK(rocprim::unique_by_key( + nullptr, + temp_storage_size_bytes, + d_keys_input, + d_values_input, + test_utils::wrap_in_identity_iterator(d_keys_input), + test_utils::wrap_in_identity_iterator(d_values_input), + d_selected_count_output, + input_keys.size(), + op_type(), + stream, + debug_synchronous)); + + if(TestFixture::use_graphs) + graph_instance = graph_instance + = test_utils::endCaptureGraphHelper(graph, stream, true, false); + + HIP_CHECK(hipDeviceSynchronize()); + + // temp_storage_size_bytes must be >0 + ASSERT_GT(temp_storage_size_bytes, 0); + + // allocate temporary storage + void* d_temp_storage = nullptr; + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + HIP_CHECK(hipDeviceSynchronize()); + + if(TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + + // Run + HIP_CHECK(rocprim::unique_by_key( + d_temp_storage, + temp_storage_size_bytes, + d_keys_input, + d_values_input, + test_utils::wrap_in_identity_iterator(d_keys_input), + test_utils::wrap_in_identity_iterator(d_values_input), + d_selected_count_output, + input_keys.size(), + op_type(), + stream, + debug_synchronous)); + + if(TestFixture::use_graphs) + graph_instance = graph_instance + = test_utils::endCaptureGraphHelper(graph, stream, true, false); + + HIP_CHECK(hipDeviceSynchronize()); + + // Check if number of selected value is as expected + unsigned int selected_count_output = 0; + HIP_CHECK(hipMemcpy(&selected_count_output, + d_selected_count_output, + sizeof(unsigned int), + hipMemcpyDeviceToHost)); + HIP_CHECK(hipDeviceSynchronize()); + ASSERT_EQ(selected_count_output, expected_keys.size()); + + // Check if outputs are as expected + std::vector output_keys(input_keys.size()); + 
HIP_CHECK(hipMemcpy(output_keys.data(), + d_keys_input, + output_keys.size() * sizeof(output_keys[0]), + hipMemcpyDeviceToHost)); + std::vector output_values(input_values.size()); + HIP_CHECK(hipMemcpy(output_values.data(), + d_values_input, + output_values.size() * sizeof(output_values[0]), + hipMemcpyDeviceToHost)); + HIP_CHECK(hipDeviceSynchronize()); + ASSERT_NO_FATAL_FAILURE( + test_utils::assert_eq(output_keys, expected_keys, expected_keys.size())); + ASSERT_NO_FATAL_FAILURE( + test_utils::assert_eq(output_values, expected_values, expected_values.size())); + + hipFree(d_keys_input); + hipFree(d_values_input); + hipFree(d_selected_count_output); + hipFree(d_temp_storage); + + if(TestFixture::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); + } + } + } + + if(TestFixture::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); +} + class RocprimDeviceSelectLargeInputTests : public ::testing::TestWithParam> { public: const bool debug_synchronous = false; diff --git a/test/rocprim/test_invoke_result.cpp b/test/rocprim/test_invoke_result.cpp new file mode 100644 index 000000000..c956426c8 --- /dev/null +++ b/test/rocprim/test_invoke_result.cpp @@ -0,0 +1,123 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "../common_test_header.hpp" +#include "rocprim/types.hpp" + +#include +#include + +#include +#include + +template +struct device_plus +{ + __device__ inline constexpr T operator()(const T& a, const T& b) const + { + return a + b; + } +}; + +template +struct RocprimTypeInvokeResultParams +{ + using input_type = InputType; + using function = Function; + using expected_type = ExpectedType; +}; + +template +class RocprimInvokeResultBinOpTests : public ::testing::Test +{ +public: + using input_type = typename Params::input_type; + using function = typename Params::function; + using expected_type = typename Params::expected_type; +}; + +typedef ::testing::Types< + RocprimTypeInvokeResultParams>, + RocprimTypeInvokeResultParams>, + RocprimTypeInvokeResultParams>, + RocprimTypeInvokeResultParams>, + RocprimTypeInvokeResultParams>, + RocprimTypeInvokeResultParams>, + RocprimTypeInvokeResultParams, bool>, + RocprimTypeInvokeResultParams, bool>, + RocprimTypeInvokeResultParams, bool>> + RocprimInvokeResultBinOpTestsParams; + +TYPED_TEST_SUITE(RocprimInvokeResultBinOpTests, RocprimInvokeResultBinOpTestsParams); + +TYPED_TEST(RocprimInvokeResultBinOpTests, HostInvokeResult) +{ + using input_type = typename TestFixture::input_type; + using binary_function = typename TestFixture::function; + using expected_type = typename TestFixture::expected_type; + + using resulting_type = rocprim::invoke_result_binary_op_t; + + // Compile and check on host + static_assert(std::is_same::value, + 
"Resulting type is not equal to expected type!"); +} + +template +struct static_cast_op +{ + __device__ inline constexpr ToType operator()(FromType a) const + { + return static_cast(a); + } +}; + +template +class RocprimInvokeResultUnOpTests : public ::testing::Test +{ +public: + using input_type = typename Params::input_type; + using function = typename Params::function; + using expected_type = typename Params::expected_type; +}; + +typedef ::testing::Types< + RocprimTypeInvokeResultParams, float>, + RocprimTypeInvokeResultParams, + rocprim::bfloat16>, + RocprimTypeInvokeResultParams>> + RocprimInvokeResultUnOpTestsParams; + +TYPED_TEST_SUITE(RocprimInvokeResultUnOpTests, RocprimInvokeResultUnOpTestsParams); + +TYPED_TEST(RocprimInvokeResultUnOpTests, HostInvokeResult) +{ + using input_type = typename TestFixture::input_type; + using unary_function = typename TestFixture::function; + using expected_type = typename TestFixture::expected_type; + + using resulting_type = rocprim::invoke_result_t; + + static_assert(std::is_same::value, + "Resulting type is not equal to expected type!"); +} diff --git a/test/rocprim/test_utils.hpp b/test/rocprim/test_utils.hpp index 13dea15b6..f5010e8eb 100644 --- a/test/rocprim/test_utils.hpp +++ b/test/rocprim/test_utils.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -21,7 +21,6 @@ #ifndef TEST_TEST_UTILS_HPP_ #define TEST_TEST_UTILS_HPP_ -#include #include #include #include @@ -446,12 +445,8 @@ void iota(ForwardIt first, ForwardIt last, T value) } template -struct DeviceSelectWarpSize -{ - static constexpr unsigned value = ::rocprim::device_warp_size() >= LogicalWarpSize - ? 
LogicalWarpSize - : ::rocprim::device_warp_size(); -}; +__device__ constexpr bool device_test_enabled_for_warp_size_v + = ::rocprim::device_warp_size() >= LogicalWarpSize; } // end test_utils namespace diff --git a/test/rocprim/test_utils_assertions.hpp b/test/rocprim/test_utils_assertions.hpp index 45b65583e..073b783c9 100644 --- a/test/rocprim/test_utils_assertions.hpp +++ b/test/rocprim/test_utils_assertions.hpp @@ -245,8 +245,7 @@ void assert_bit_eq(const std::vector& result, const std::vector& expected) } } } - -#if defined(__GNUC__) || defined(__clang__) +#if ROCPRIM_HAS_INT128_SUPPORT inline void assert_bit_eq(const std::vector<__int128_t>& result, const std::vector<__int128_t>& expected) { diff --git a/test/rocprim/test_utils_types.hpp b/test/rocprim/test_utils_types.hpp index 1f793cbda..fb901adc9 100644 --- a/test/rocprim/test_utils_types.hpp +++ b/test/rocprim/test_utils_types.hpp @@ -158,9 +158,13 @@ typedef ::testing::Types< typedef ::testing::Types), block_param_type(uint8_t, short), - block_param_type(int8_t, float), + block_param_type(int8_t, float) +#if ROCPRIM_HAS_INT128_SUPPORT + , block_param_type(__uint128_t, short), - block_param_type(__int128_t, float)> + block_param_type(__int128_t, float) +#endif + > BlockParamsIntegralExtended; typedef ::testing::Types< diff --git a/test/rocprim/test_warp_exchange.cpp b/test/rocprim/test_warp_exchange.cpp index 769d6764a..0906b3681 100644 --- a/test/rocprim/test_warp_exchange.cpp +++ b/test/rocprim/test_warp_exchange.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -25,12 +25,13 @@ #include -template< - class T, - unsigned int ItemsPerThread, - unsigned int WarpSize, - class ExchangeOp -> +#include +#include +#include + +#include + +template struct Params { using type = T; @@ -128,70 +129,67 @@ struct ScatterToStripedOp } }; -using WarpExchangeTestParams = ::testing::Types< - Params, - Params, - Params, - Params, - Params, - Params, - - Params, - Params, - Params, - Params, - Params, - - Params, - Params, - Params, - Params, - Params, - Params, - - Params, - Params, - Params, - Params, - Params ->; - -template< - class T, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int LogicalWarpSize, - class Op -> -__global__ -__launch_bounds__(BlockSize) -void warp_exchange_kernel(T* d_input, - T* d_output) +using WarpExchangeTestParams + = ::testing::Types, + Params, + Params, + Params, + Params, + Params, + + Params, + Params, + Params, + Params, + Params, + + Params, + Params, + Params, + Params, + Params, + Params, + + Params, + Params, + Params, + Params, + Params>; + +template +__device__ auto warp_exchange_test(T* d_input, T* d_output) + -> std::enable_if_t> { - static_assert(BlockSize == LogicalWarpSize, - "BlockSize must be equal to LogicalWarpSize in this test"); - using warp_exchange_type = ::rocprim::warp_exchange< - T, - ItemsPerThread, - test_utils::DeviceSelectWarpSize::value - >; - - ROCPRIM_SHARED_MEMORY typename warp_exchange_type::storage_type storage; + using warp_exchange_type = ::rocprim::warp_exchange; + constexpr unsigned int num_warps = ::rocprim::device_warp_size() / LogicalWarpSize; + ROCPRIM_SHARED_MEMORY typename warp_exchange_type::storage_type storage[num_warps]; T thread_data[ItemsPerThread]; for(unsigned int i = 0; i < ItemsPerThread; i++) { - thread_data[i] = d_input[hipThreadIdx_x * ItemsPerThread + i]; + 
thread_data[i] = d_input[threadIdx.x * ItemsPerThread + i]; } - Op{}(warp_exchange_type(), thread_data, storage); + const unsigned int warp_id = threadIdx.x / LogicalWarpSize; + Op{}(warp_exchange_type(), thread_data, storage[warp_id]); for(unsigned int i = 0; i < ItemsPerThread; i++) { - d_output[hipThreadIdx_x * ItemsPerThread + i] = thread_data[i]; + d_output[threadIdx.x * ItemsPerThread + i] = thread_data[i]; } } +template +__device__ auto warp_exchange_test(T* /*d_input*/, T* /*d_output*/) + -> std::enable_if_t> +{} + +template +__global__ void warp_exchange_kernel(T* d_input, T* d_output) +{ + warp_exchange_test(d_input, d_output); +} + template std::vector stripe_vector(const std::vector& v, const size_t warp_size, @@ -217,13 +215,16 @@ TYPED_TEST(WarpExchangeTest, WarpExchange) using T = typename TestFixture::params::type; constexpr unsigned int warp_size = TestFixture::params::warp_size; constexpr unsigned int items_per_thread = TestFixture::params::items_per_thread; - using exchange_op = typename TestFixture::params::exchange_op; - constexpr unsigned int block_size = warp_size; - constexpr unsigned int items_count = items_per_thread * block_size; + using exchange_op = typename TestFixture::params::exchange_op; - int device_id = test_common_utils::obtain_device_from_ctest(); + const int device_id = test_common_utils::obtain_device_from_ctest(); SKIP_IF_UNSUPPORTED_WARP_SIZE(warp_size, device_id); + unsigned int hw_warp_size; + HIP_CHECK(::rocprim::host_warp_size(device_id, hw_warp_size)); + const unsigned int block_size = hw_warp_size; + const unsigned int items_count = items_per_thread * block_size; + std::vector input(items_count); std::iota(input.begin(), input.end(), static_cast(0)); auto expected = input; @@ -238,20 +239,10 @@ TYPED_TEST(WarpExchangeTest, WarpExchange) HIP_CHECK(hipMemcpy(d_input, input.data(), items_count * sizeof(T), hipMemcpyHostToDevice)); T* d_output{}; HIP_CHECK(hipMalloc(&d_output, items_count * sizeof(T))); + 
HIP_CHECK(hipMemset(d_output, 0, items_count * sizeof(T))); - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - warp_exchange_kernel< - T, - block_size, - items_per_thread, - warp_size, - exchange_op - > - ), - dim3(1), dim3(block_size), 0, 0, - d_input, d_output - ); + warp_exchange_kernel + <<>>(d_input, d_output); HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -270,13 +261,12 @@ TYPED_TEST(WarpExchangeTest, WarpExchange) ASSERT_EQ(expected, output); } -using WarpExchangeScatterTestParams = ::testing::Types< - Params, - Params, - Params, - Params, - Params - >; +using WarpExchangeScatterTestParams = ::testing::Types, + Params, + Params, + Params, + Params, + Params>; template class WarpExchangeScatterTest : public ::testing::Test @@ -285,46 +275,46 @@ class WarpExchangeScatterTest : public ::testing::Test using params = Params; }; -template< - class T, - class OffsetT, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int LogicalWarpSize, - class Op -> -__global__ -__launch_bounds__(BlockSize) -void warp_exchange_scatter_kernel(T* d_input, - T* d_output, - OffsetT* d_ranks) +template +__device__ auto warp_exchange_scatter_test(T* d_input, T* d_output, OffsetT* d_ranks) + -> std::enable_if_t> { - static_assert(BlockSize == LogicalWarpSize, - "BlockSize must be equal to LogicalWarpSize in this test"); - using warp_exchange_type = ::rocprim::warp_exchange< - T, - ItemsPerThread, - test_utils::DeviceSelectWarpSize::value - >; + using warp_exchange_type = ::rocprim::warp_exchange; - ROCPRIM_SHARED_MEMORY typename warp_exchange_type::storage_type storage; + constexpr unsigned int num_warps = ::rocprim::device_warp_size() / LogicalWarpSize; + ROCPRIM_SHARED_MEMORY typename warp_exchange_type::storage_type storage[num_warps]; T thread_data[ItemsPerThread]; OffsetT thread_ranks[ItemsPerThread]; for(unsigned int i = 0; i < ItemsPerThread; i++) { - thread_data[i] = d_input[hipThreadIdx_x * ItemsPerThread + i]; - thread_ranks[i] = 
d_ranks[hipThreadIdx_x * ItemsPerThread + i]; + thread_data[i] = d_input[threadIdx.x * ItemsPerThread + i]; + thread_ranks[i] = d_ranks[threadIdx.x * ItemsPerThread + i]; } - Op{}(warp_exchange_type(), thread_data, thread_ranks, storage); + const unsigned int warp_id = threadIdx.x / LogicalWarpSize; + warp_exchange_type{}.scatter_to_striped(thread_data, + thread_data, + thread_ranks, + storage[warp_id]); for(unsigned int i = 0; i < ItemsPerThread; i++) { - d_output[hipThreadIdx_x * ItemsPerThread + i] = thread_data[i]; + d_output[threadIdx.x * ItemsPerThread + i] = thread_data[i]; } } +template +__device__ auto warp_exchange_scatter_test(T* /*d_input*/, T* /*d_output*/, OffsetT* /*d_ranks*/) + -> std::enable_if_t> +{} + +template +__global__ void warp_exchange_scatter_kernel(T* d_input, T* d_output, OffsetT* d_ranks) +{ + warp_exchange_scatter_test(d_input, d_output, d_ranks); +} + TYPED_TEST_SUITE(WarpExchangeScatterTest, WarpExchangeScatterTestParams); TYPED_TEST(WarpExchangeScatterTest, WarpExchangeScatter) @@ -332,43 +322,42 @@ TYPED_TEST(WarpExchangeScatterTest, WarpExchangeScatter) using T = typename TestFixture::params::type; constexpr unsigned int warp_size = TestFixture::params::warp_size; constexpr unsigned int items_per_thread = TestFixture::params::items_per_thread; - using exchange_op = typename TestFixture::params::exchange_op; - constexpr unsigned int block_size = warp_size; - constexpr unsigned int items_count = items_per_thread * block_size; using OffsetT = unsigned short; - int device_id = test_common_utils::obtain_device_from_ctest(); + const int device_id = test_common_utils::obtain_device_from_ctest(); SKIP_IF_UNSUPPORTED_WARP_SIZE(warp_size, device_id); + unsigned int hw_warp_size; + HIP_CHECK(::rocprim::host_warp_size(device_id, hw_warp_size)); + const unsigned int block_size = hw_warp_size; + const unsigned int items_count = items_per_thread * block_size; std::vector input(items_count); std::iota(input.begin(), input.end(), static_cast(0)); 
- auto expected = input; - std::shuffle(input.begin(), input.end(), std::default_random_engine{std::random_device{}()}); - std::vector ranks(input.begin(), input.end()); + const auto expected = stripe_vector(input, warp_size, items_per_thread); + std::default_random_engine prng(std::random_device{}()); + for(auto it = input.begin(), end = input.end(); it != end; it += warp_size) + { + std::shuffle(it, it + warp_size, prng); + } + std::vector ranks(items_count); + std::transform(input.begin(), + input.end(), + ranks.begin(), + [](const T input_val) + { return static_cast(input_val) % (warp_size * items_per_thread); }); T* d_input{}; HIP_CHECK(hipMalloc(&d_input, items_count * sizeof(T))); HIP_CHECK(hipMemcpy(d_input, input.data(), items_count * sizeof(T), hipMemcpyHostToDevice)); T* d_output{}; HIP_CHECK(hipMalloc(&d_output, items_count * sizeof(T))); + HIP_CHECK(hipMemset(d_output, 0, items_count * sizeof(T))); OffsetT* d_ranks{}; HIP_CHECK(hipMalloc(&d_ranks, items_count * sizeof(OffsetT))); HIP_CHECK(hipMemcpy(d_ranks, ranks.data(), items_count * sizeof(OffsetT), hipMemcpyHostToDevice)); - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - warp_exchange_scatter_kernel< - T, - OffsetT, - block_size, - items_per_thread, - warp_size, - exchange_op - > - ), - dim3(1), dim3(block_size), 0, 0, - d_input, d_output, d_ranks - ); + warp_exchange_scatter_kernel + <<>>(d_input, d_output, d_ranks); HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -379,10 +368,5 @@ TYPED_TEST(WarpExchangeScatterTest, WarpExchangeScatter) HIP_CHECK(hipFree(d_output)); HIP_CHECK(hipFree(d_ranks)); - if(std::is_same::value) - { - expected = stripe_vector(expected, warp_size, items_per_thread); - } - ASSERT_EQ(expected, output); -} \ No newline at end of file +} diff --git a/test/rocprim/test_warp_load.cpp b/test/rocprim/test_warp_load.cpp index e1a14cfa0..c00e0501d 100644 --- a/test/rocprim/test_warp_load.cpp +++ b/test/rocprim/test_warp_load.cpp @@ -1,6 +1,6 @@ // MIT License // -// 
Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -24,6 +24,7 @@ #include "test_utils.hpp" #include +#include template< class T, @@ -78,29 +79,20 @@ using WarpLoadTestParams = ::testing::Types< Params >; -template< - class T, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int LogicalWarpSize, - ::rocprim::warp_load_method Method -> -__global__ -__launch_bounds__(BlockSize) -void warp_load_kernel(T* d_input, - T* d_output) +template +__device__ auto warp_load_test(T* d_input, T* d_output) + -> std::enable_if_t> { static_assert(BlockSize % LogicalWarpSize == 0, "LogicalWarpSize must be a divisor of BlockSize"); - using warp_load_type = ::rocprim::warp_load< - T, - ItemsPerThread, - test_utils::DeviceSelectWarpSize::value, - Method - >; + using warp_load_type = ::rocprim::warp_load; constexpr unsigned int tile_size = ItemsPerThread * LogicalWarpSize; constexpr unsigned int num_warps = BlockSize / LogicalWarpSize; - const unsigned int warp_id = hipThreadIdx_x / LogicalWarpSize; + const unsigned int warp_id = threadIdx.x / LogicalWarpSize; ROCPRIM_SHARED_MEMORY typename warp_load_type::storage_type storage[num_warps]; T thread_data[ItemsPerThread]; @@ -109,35 +101,43 @@ void warp_load_kernel(T* d_input, for(unsigned int i = 0; i < ItemsPerThread; i++) { - d_output[hipThreadIdx_x * ItemsPerThread + i] = thread_data[i]; + d_output[threadIdx.x * ItemsPerThread + i] = thread_data[i]; } } -template< - class T, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int LogicalWarpSize, - ::rocprim::warp_load_method Method -> -__global__ -__launch_bounds__(BlockSize) -void warp_load_guarded_kernel(T* d_input, - T* d_output, - int valid_items, - T oob_default) +template 
+__device__ auto warp_load_test(T* /*d_input*/, T* /*d_output*/) + -> std::enable_if_t> +{} + +template +__global__ __launch_bounds__(BlockSize) void warp_load_kernel(T* d_input, T* d_output) +{ + warp_load_test(d_input, d_output); +} + +template +__device__ auto warp_load_guarded_test(T* d_input, T* d_output, int valid_items, T oob_default) + -> std::enable_if_t> { static_assert(BlockSize % LogicalWarpSize == 0, "LogicalWarpSize must be a divisor of BlockSize"); - using warp_load_type = ::rocprim::warp_load< - T, - ItemsPerThread, - test_utils::DeviceSelectWarpSize::value, - Method - >; + using warp_load_type = ::rocprim::warp_load; constexpr unsigned int tile_size = ItemsPerThread * LogicalWarpSize; constexpr unsigned int num_warps = BlockSize / LogicalWarpSize; - const unsigned warp_id = hipThreadIdx_x / LogicalWarpSize; + const unsigned warp_id = threadIdx.x / LogicalWarpSize; ROCPRIM_SHARED_MEMORY typename warp_load_type::storage_type storage[num_warps]; T thread_data[ItemsPerThread]; @@ -152,10 +152,36 @@ void warp_load_guarded_kernel(T* d_input, for(unsigned int i = 0; i < ItemsPerThread; i++) { - d_output[hipThreadIdx_x * ItemsPerThread + i] = thread_data[i]; + d_output[threadIdx.x * ItemsPerThread + i] = thread_data[i]; } } +template +__device__ auto + warp_load_guarded_test(T* /*d_input*/, T* /*d_output*/, int /*valid_items*/, T /*oob_default*/) + -> std::enable_if_t> +{} + +template +__global__ __launch_bounds__(BlockSize) void warp_load_guarded_kernel(T* d_input, + T* d_output, + int valid_items, + T oob_default) +{ + warp_load_guarded_test(d_input, + d_output, + valid_items, + oob_default); +} + template std::vector stripe_vector(const std::vector& v, const size_t warp_size, @@ -197,19 +223,8 @@ TYPED_TEST(WarpLoadTest, WarpLoad) T* d_output{}; HIP_CHECK(hipMalloc(&d_output, items_count * sizeof(T))); - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - warp_load_kernel< - T, - block_size, - items_per_thread, - warp_size, - method - > - ), - dim3(1), 
dim3(block_size), 0, 0, - d_input, d_output - ); + warp_load_kernel + <<>>(d_input, d_output); HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -251,20 +266,8 @@ TYPED_TEST(WarpLoadTest, WarpLoadGuarded) T* d_output{}; HIP_CHECK(hipMalloc(&d_output, items_count * sizeof(T))); - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - warp_load_guarded_kernel< - T, - block_size, - items_per_thread, - warp_size, - method - > - ), - dim3(1), dim3(block_size), 0, 0, - d_input, d_output, - valid_items, oob_default - ); + warp_load_guarded_kernel + <<>>(d_input, d_output, valid_items, oob_default); HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); diff --git a/test/rocprim/test_warp_scan.hpp b/test/rocprim/test_warp_scan.hpp index 7eb807f23..99d82f33c 100644 --- a/test/rocprim/test_warp_scan.hpp +++ b/test/rocprim/test_warp_scan.hpp @@ -417,6 +417,149 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveScan) } +typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveScanWoInit) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using T = typename TestFixture::params::type; + // for bfloat16 and half we use double for host-side accumulation + using binary_op_type_host = typename test_utils::select_plus_operator_host::type; + binary_op_type_host binary_op_host; + using acc_type = typename test_utils::select_plus_operator_host::acc_type; + using cast_type = typename test_utils::select_plus_operator_host::cast_type; + + // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() + static constexpr size_t logical_warp_size = TestFixture::params::warp_size; + + // The different warp sizes + static constexpr size_t ws32{ROCPRIM_WARP_SIZE_32}; + static constexpr size_t ws64{ROCPRIM_WARP_SIZE_64}; + + // Block size of warp size 32 + static constexpr size_t block_size_ws32 + = 
rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; + + // Block size of warp size 64 + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; + + unsigned int current_device_warp_size; + HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); + + const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; + const unsigned int grid_size = 4; + const size_t size = block_size * grid_size; + + // Check if warp size is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only WarpSize 32 and 64 is supported + { + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%d. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); + GTEST_SKIP(); + } + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + // Generate data + std::vector input = test_utils::get_random_data(size, 2, 50, seed_value); + std::vector output(size); + std::vector expected(input.size(), T(0)); + + // Calculate expected results on host + for(size_t i = 0; i < input.size() / logical_warp_size; i++) + { + // expected[i * logical_warp_size] is unspecified because init is not passed + acc_type accumulator(input[i * logical_warp_size]); + + static_assert(logical_warp_size > 2, "logical_warp_size assumed to be at least 2."); + expected[i * logical_warp_size + 1] = static_cast(accumulator); + + for(size_t j = 2; j < logical_warp_size; j++) + { + auto idx = i * logical_warp_size + j; + accumulator = binary_op_host(input[idx - 1], accumulator); + expected[idx] = static_cast(accumulator); + } + } + + // Writing to device memory + T* device_input; + HIP_CHECK(test_common_utils::hipMallocHelper( + &device_input, + input.size() * sizeof(typename decltype(input)::value_type))); + T* device_output; + HIP_CHECK(test_common_utils::hipMallocHelper( + &device_output, + output.size() * sizeof(typename decltype(output)::value_type))); + + HIP_CHECK( + hipMemcpy(device_input, input.data(), input.size() * sizeof(T), hipMemcpyHostToDevice)); + + // Launching kernel + if(current_device_warp_size == ws32) + { + hipLaunchKernelGGL( + HIP_KERNEL_NAME( + warp_exclusive_scan_wo_init_kernel), + dim3(grid_size), + dim3(block_size_ws32), + 0, + 0, + device_input, + device_output); + } + else if(current_device_warp_size == ws64) + { + hipLaunchKernelGGL( + HIP_KERNEL_NAME( + warp_exclusive_scan_wo_init_kernel), + dim3(grid_size), + dim3(block_size_ws64), + 0, + 0, + device_input, + device_output); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipDeviceSynchronize()); + + // Read from device memory + HIP_CHECK(hipMemcpy(output.data(), + device_output, + output.size() * sizeof(T), + 
hipMemcpyDeviceToHost)); + + // The first value of each logical warp has an unspecified result, expect whatever we got + // for those values to not fail the test. + for(size_t i = 0; i < input.size() / logical_warp_size; i++) + { + expected[i * logical_warp_size] = output[i * logical_warp_size]; + } + + // Validating results + test_utils::assert_near(output, expected, test_utils::precision * logical_warp_size); + + HIP_CHECK(hipFree(device_input)); + HIP_CHECK(hipFree(device_output)); + } +} + //typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScan) typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScan) { @@ -571,6 +714,175 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScan) } +typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScanWoInit) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using T = typename TestFixture::params::type; + // for bfloat16 and half we use double for host-side accumulation + using binary_op_type_host = typename test_utils::select_plus_operator_host::type; + binary_op_type_host binary_op_host; + using acc_type = typename test_utils::select_plus_operator_host::acc_type; + using cast_type = typename test_utils::select_plus_operator_host::cast_type; + + // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() + static constexpr size_t logical_warp_size = TestFixture::params::warp_size; + + // The different warp sizes + static constexpr size_t ws32 = size_t(ROCPRIM_WARP_SIZE_32); + static constexpr size_t ws64 = size_t(ROCPRIM_WARP_SIZE_64); + + // Block size of warp size 32 + static constexpr size_t block_size_ws32 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? 
rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; + + // Block size of warp size 64 + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; + + unsigned int current_device_warp_size; + HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); + + const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; + const unsigned int grid_size = 4; + const size_t size = block_size * grid_size; + + // Check if warp size is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only WarpSize 32 and 64 is supported + { + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%d. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); + GTEST_SKIP(); + } + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + // Generate data + std::vector input = test_utils::get_random_data(size, 2, 50, seed_value); + std::vector output(size); + std::vector output_reductions(size / logical_warp_size); + std::vector expected(input.size(), T(0)); + std::vector expected_reductions(output_reductions.size(), T(0)); + + // Calculate expected results on host + for(size_t i = 0; i < input.size() / logical_warp_size; i++) + { + // expected[i * logical_warp_size] is unspecified because init is not passed + acc_type accumulator(input[i * logical_warp_size]); + + static_assert(logical_warp_size > 2, "logical_warp_size assumed to be at least 2."); + expected[i * logical_warp_size + 1] = static_cast(accumulator); + + for(size_t j = 2; j < logical_warp_size; j++) + { + auto idx = i * logical_warp_size + j; + accumulator = binary_op_host(input[idx - 1], accumulator); + expected[idx] = static_cast(accumulator); + } + + acc_type accumulator_reductions(0); + for(size_t j = 0; j < logical_warp_size; j++) + { + auto idx = i * logical_warp_size + j; + accumulator_reductions = binary_op_host(input[idx], accumulator_reductions); + expected_reductions[i] = static_cast(accumulator_reductions); + } + } + + // Writing to device memory + T* device_input; + HIP_CHECK(test_common_utils::hipMallocHelper( + &device_input, + input.size() * sizeof(typename decltype(input)::value_type))); + T* device_output; + HIP_CHECK(test_common_utils::hipMallocHelper( + &device_output, + output.size() * sizeof(typename decltype(output)::value_type))); + T* device_output_reductions; + HIP_CHECK(test_common_utils::hipMallocHelper( + &device_output_reductions, + output_reductions.size() * sizeof(typename decltype(output_reductions)::value_type))); + + HIP_CHECK( + hipMemcpy(device_input, input.data(), input.size() * sizeof(T), hipMemcpyHostToDevice)); + + // Launching kernel + if(current_device_warp_size == ws32) + { 
+ hipLaunchKernelGGL( + HIP_KERNEL_NAME(warp_exclusive_scan_reduce_wo_init_kernel), + dim3(grid_size), + dim3(block_size_ws32), + 0, + 0, + device_input, + device_output, + device_output_reductions); + } + else if(current_device_warp_size == ws64) + { + hipLaunchKernelGGL( + HIP_KERNEL_NAME(warp_exclusive_scan_reduce_wo_init_kernel), + dim3(grid_size), + dim3(block_size_ws64), + 0, + 0, + device_input, + device_output, + device_output_reductions); + } + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipDeviceSynchronize()); + + // Read from device memory + HIP_CHECK(hipMemcpy(output.data(), + device_output, + output.size() * sizeof(T), + hipMemcpyDeviceToHost)); + + HIP_CHECK(hipMemcpy(output_reductions.data(), + device_output_reductions, + output_reductions.size() * sizeof(T), + hipMemcpyDeviceToHost)); + + // The first value of each logical warp has an unspecified result, expect whatever we got + // for those values to not fail the test. + for(size_t i = 0; i < input.size() / logical_warp_size; i++) + { + expected[i * logical_warp_size] = output[i * logical_warp_size]; + } + + // Validating results + test_utils::assert_near(output, expected, test_utils::precision * logical_warp_size); + test_utils::assert_near(output_reductions, + expected_reductions, + test_utils::precision * logical_warp_size); + + HIP_CHECK(hipFree(device_input)); + HIP_CHECK(hipFree(device_output)); + HIP_CHECK(hipFree(device_output_reductions)); + } +} + typed_test_def(RocprimWarpScanTests, name_suffix, Scan) { int device_id = test_common_utils::obtain_device_from_ctest(); diff --git a/test/rocprim/test_warp_scan.kernels.hpp b/test/rocprim/test_warp_scan.kernels.hpp index 11f859ec6..875389058 100644 --- a/test/rocprim/test_warp_scan.kernels.hpp +++ b/test/rocprim/test_warp_scan.kernels.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -128,6 +128,49 @@ void warp_exclusive_scan_reduce_kernel( } } +template +__global__ __launch_bounds__(BlockSize) void warp_exclusive_scan_wo_init_kernel(T* device_input, + T* device_output) +{ + static constexpr unsigned int block_warps_no = BlockSize / LogicalWarpSize; + + const unsigned int global_index = threadIdx.x + (blockIdx.x * blockDim.x); + const unsigned int block_warp_id = threadIdx.x / LogicalWarpSize; + + T value = device_input[global_index]; + + using wscan_t = rocprim::warp_scan; + __shared__ typename wscan_t::storage_type storage[block_warps_no]; + wscan_t().exclusive_scan(value, value, storage[block_warp_id]); + + device_output[global_index] = value; +} + +template +__global__ __launch_bounds__(BlockSize) void warp_exclusive_scan_reduce_wo_init_kernel( + T* device_input, T* device_output, T* device_output_reductions) +{ + static constexpr unsigned int block_warps_no = BlockSize / LogicalWarpSize; + + const unsigned int global_index = threadIdx.x + (blockIdx.x * blockDim.x); + const unsigned int block_warp_id = threadIdx.x / LogicalWarpSize; + const unsigned int lane_id = threadIdx.x % LogicalWarpSize; + const unsigned int global_warp_id = global_index / LogicalWarpSize; + + T value = device_input[global_index]; + T reduction; + + using wscan_t = rocprim::warp_scan; + __shared__ typename wscan_t::storage_type storage[block_warps_no]; + wscan_t().exclusive_scan(value, value, storage[block_warp_id], reduction); + + device_output[global_index] = value; + if(lane_id == 0) + { + device_output_reductions[global_warp_id] = reduction; + } +} + template< class T, unsigned int BlockSize, diff --git a/test/rocprim/test_warp_store.cpp b/test/rocprim/test_warp_store.cpp index 2d6a0430a..63aab5355 100644 --- a/test/rocprim/test_warp_store.cpp +++ b/test/rocprim/test_warp_store.cpp @@ -1,6 +1,6 @@ // 
MIT License // -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -24,6 +24,7 @@ #include "test_utils.hpp" #include +#include template< class T, @@ -78,66 +79,66 @@ using WarpStoreTestParams = ::testing::Types< Params >; -template< - class T, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int LogicalWarpSize, - ::rocprim::warp_store_method Method -> -__global__ -__launch_bounds__(BlockSize) -void warp_store_kernel(T* d_input, - T* d_output) +template +__device__ auto warp_store_test(T* d_input, T* d_output) + -> std::enable_if_t> { - using warp_store_type = ::rocprim::warp_store< - T, - ItemsPerThread, - test_utils::DeviceSelectWarpSize::value, - Method - >; + using warp_store_type = ::rocprim::warp_store; constexpr unsigned int tile_size = ItemsPerThread * LogicalWarpSize; constexpr unsigned int num_warps = BlockSize / LogicalWarpSize; - const unsigned int warp_id = hipThreadIdx_x / LogicalWarpSize; + const unsigned int warp_id = threadIdx.x / LogicalWarpSize; ROCPRIM_SHARED_MEMORY typename warp_store_type::storage_type storage[num_warps]; T thread_data[ItemsPerThread]; for(unsigned int i = 0; i < ItemsPerThread; ++i) { - thread_data[i] = d_input[hipThreadIdx_x * ItemsPerThread + i]; + thread_data[i] = d_input[threadIdx.x * ItemsPerThread + i]; } warp_store_type().store(d_output + warp_id * tile_size, thread_data, storage[warp_id]); } -template< - class T, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int LogicalWarpSize, - ::rocprim::warp_store_method Method -> -__global__ -__launch_bounds__(BlockSize) -void warp_store_guarded_kernel(T* d_input, - T* d_output, - int valid_items) +template +__device__ auto warp_store_test(T* /*d_input*/, T* /*d_output*/) 
+ -> std::enable_if_t> +{} + +template +__global__ __launch_bounds__(BlockSize) void warp_store_kernel(T* d_input, T* d_output) { - using warp_store_type = ::rocprim::warp_store< - T, - ItemsPerThread, - test_utils::DeviceSelectWarpSize::value, - Method - >; + warp_store_test(d_input, d_output); +} + +template +__device__ auto warp_store_guarded_test(T* d_input, T* d_output, int valid_items) + -> std::enable_if_t> +{ + using warp_store_type = ::rocprim::warp_store; constexpr unsigned int tile_size = ItemsPerThread * LogicalWarpSize; constexpr unsigned int num_warps = BlockSize / LogicalWarpSize; - const unsigned int warp_id = hipThreadIdx_x / LogicalWarpSize; + const unsigned int warp_id = threadIdx.x / LogicalWarpSize; ROCPRIM_SHARED_MEMORY typename warp_store_type::storage_type storage[num_warps]; T thread_data[ItemsPerThread]; for(unsigned int i = 0; i < ItemsPerThread; ++i) { - thread_data[i] = d_input[hipThreadIdx_x * ItemsPerThread + i]; + thread_data[i] = d_input[threadIdx.x * ItemsPerThread + i]; } warp_store_type().store(d_output + warp_id * tile_size, @@ -147,6 +148,29 @@ void warp_store_guarded_kernel(T* d_input, ); } +template +__device__ auto warp_store_guarded_test(T* /*d_input*/, T* /*d_output*/, int /*valid_items*/) + -> std::enable_if_t> +{} + +template +__global__ __launch_bounds__(BlockSize) void warp_store_guarded_kernel(T* d_input, + T* d_output, + int valid_items) +{ + warp_store_guarded_test(d_input, + d_output, + valid_items); +} + template std::vector stripe_vector(const std::vector& v, const size_t warp_size, @@ -187,19 +211,8 @@ TYPED_TEST(WarpStoreTest, WarpLoad) T* d_output{}; HIP_CHECK(hipMalloc(&d_output, items_count * sizeof(T))); - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - warp_store_kernel< - T, - block_size, - items_per_thread, - warp_size, - method - > - ), - dim3(1), dim3(block_size), 0, 0, - d_input, d_output - ); + warp_store_kernel + <<>>(d_input, d_output); HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ 
-241,19 +254,8 @@ TYPED_TEST(WarpStoreTest, WarpStoreGuarded) HIP_CHECK(hipMalloc(&d_output, items_count * sizeof(T))); HIP_CHECK(hipMemset(d_output, 0, items_count * sizeof(T))); - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - warp_store_guarded_kernel< - T, - block_size, - items_per_thread, - warp_size, - method - > - ), - dim3(1), dim3(block_size), 0, 0, - d_input, d_output, valid_items - ); + warp_store_guarded_kernel + <<>>(d_input, d_output, valid_items); HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize());