From 013fb2ca7ae0c08afbfc84beb2018dd2c572a348 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc=20Serf=C5=91z=C5=91?= Date: Tue, 19 Mar 2024 20:59:47 +0100 Subject: [PATCH] StreamHPC 2024-01-16 (#509) * Config tuning and dynamic dispatch for device segmented radix sort * add script to autotune:build job, to remove executables if the size is too large for artifact * fix ci script * fix typo * autotune:execute-tuning checks if build artifact has executables * fix incrementing in script * fix checking the executables * fix review comments * fix ci script * fix printing error by usage using multiline yaml * test ci script * let's try the backslash * this should work now * move autotune:execute-tuning to use multiline yaml for script * docs:(partition_two_way): add partition_two_way to sphinx * docs(batch_memcpy): add batch_memcpy to sphinx * docs(intrinsics): add match_any and group_elect to sphinx * docs(memcpy): improve consistency with other pages * docs: Migrate to using rocm-docs-core with the extension config * docs: Declare TOCs in _toc.yml.in This fixes the warnings given by sphinx_external_toc. Be explicit and add toc `tableofcontents` directives where the TOCs should be inserted. See https://github.com/executablebooks/sphinx-external-toc#add-a-toc-to-a-pages-content for more info. * Add memcpy to the summary of operations * Add exclusive_scan interfaces without initial value to warp_scan CUB has these, therefore hipCUB needs them too. Currently these are being worked around in hipCUB by using undocumented APIs of rocPRIM (`to_exclusive`) See for example [warp_scan.hpp:107 in hipCUB](https://github.com/ROCmSoftwarePlatform/hipCUB/blob/f459480f78164328214b75b16ffef338f1d4bc89/hipcub/include/hipcub/backend/rocprim/warp/warp_scan.hpp#L107) * Add tests for warp exclusive_scan without initial value The tests are based on the current tests modified to skip checking the first value of each warp. 
* Add warp_scan::exclusive_scan overloads wo initial value to CHANGELOG * Add too large logical warp runtime errors to the new warp_scan function * Improve documentation of warp_scan::exclusive_scan wo initial value - Add the no initial value part to the brief description - Hide the `enable_if` and the overloads required for runtime errors from the docs. * fix: Fixed doxygen warning in device_config_helper.hpp * style: Minor edits to warp_exclusive_scan wo init * fix(test_device_adjacent_difference.cpp): fixed unused variable warning * Consistent doxygen parameters in the new excl. scan APIs * Fix, simplify warp_id & lane_id variables in warp_exclusive_scan wo init * Fix MSVC warning due to deprecated getenv * Fix formatting in test_warp_scan.hpp * clang-format * Fix linker issues for test debug compilation * Use shuffle instead of shuffle_random in hipgraph test * style(device_scan.hpp): use chevron-style ('<<<...>>>') kernel launching * fix(device_scan.hpp): derive the intermediate accumulator type from the scan operator instead of the initial value/input type This reduces the number of type conversions and makes the accuracy of the operation directly dependent on the chosen binary operator. * fix(device_scan_by_key.hpp): derive the intermediate accumulator type from the scan operator instead of the initial value/input type * test(test_device_scan.cpp): scan tests derive accumulator type from output of operator on device and host * test(test_device_scan.cpp): scan by key tests derive accumulator type from output of operator on device and host * Disable __int128_t tests on platforms without support * docs(device_scan.hpp): reflect device scan accumulator changes in documentation * fix(device_scan.hpp): use rocprim::detail::match_result_type instead of std::result_of This fixes compile time errors where the resulting type cannot be derived from device-only lambdas and functors. 
* fix(device_scan.hpp,device_scan_by_key.hpp): revert default accumulator type in scan algorithms back to using result type * revert(test_device_scan.cpp): scan by key tests derive accumulator type from output of operator on device and host This reverts commit 5bb5b1648fbefb303e42e65925c8afd59589426a. * feat(device_scan.hpp): added an optional type parameter for the accumulator type in scan algorithms By default the accumulator type is based on the scan operator. This is the intended behaviour for hipCUB, but rocThrust still bases this on the value type of the input iterator. To accommodate both requirements, the accumulator type had to be exposed. * revert(test_device_scan.cpp): scan tests derive accumulator type from output of operator on device and host This reverts commit 0d2d4820ce7f87f55d6f725b85a9ef3e63d8f0a4. * docs(device_scan.hpp,device_scan_by_key.hpp): updated the documentation to include the optional type parameter for the accumulator in scan algorithms * revert(device_scan.hpp): reflect device scan accumulator changes in documentation This reverts commit 9d0ab0e58fe8954fef4b9d2ad154c71f6f6aaeb1. * style: improve formatting * docs(changelog.md): reflect changes to intermediate type in changelog * docs(changelog.md): update the changelog to include the addition of the optional accumulator type in scan algorithms * fix(device_scan.hpp,device_scan_by_key.hpp): use initial value for accumulator for exclusive scan * docs(device_scan.hpp,device_scan_by_key.hpp): update accumulator type parameter documentation * style: update copyright * style(test_device_scan.cpp): fix formatting * feat(type_traits.hpp): expose 'invoke_result' and 'invoke_result_binary_op' These were previously internal functions. 
* removed unused code warnings in benchmarks and added warning compiler flags to gitlab ci for benchmarks * docs: improve documentation * style: update copyright and fix style * Formatting and copyright data changes * refactor(match_result_type.hpp,type_traits.hpp): moved implementation details of invoke_result to type_traits.hpp * test(test_type_traits.cpp): added tests for 'invoke_result' * refactor(test_invoke_result.cpp): rename from 'test_type_traits' to 'test_invoke_result' * style: update copyright dates * test(test_invoke_result.cpp): test also cover device-only functions * feat(type_traits.hpp): add c++17-styled aliases for 'invoke_result' and 'invoke_result_binary_op' * style: update style * Removed redundant inheritance in device templates * refactor(test_invoke_result.cpp): use fixed-width integer types * Fixed linting * declare the return type of lambda used in adjacent difference, to avoid compile errors * Fixed warp_exchange blocked_to_striped_shuffle and striped_to_blocked_shuffle The logical warp size was not passed to the shuffle operation, therefore only the first logical warp in the block was executed. * add new api calls for device_adjacent_difference * Improved warp_exchange test suite Multiple logical warps are executed per test. Added tests with 2 and 8 byte value types. 
* rename new api function to avoid overload and make things clearer * Updated copyright dates * fix call not changed in test * Updated changelog * fix review comments update docs comments refactor test_device_adjacent_difference in_place part * add test cases for device_adjacent_difference to check for input iterators not returning value_type for operator[] * fix format check large index test * fix format and merge errors * fix review comments rename to indirect iterator simplify indirect iterator * change adjacent_difference_alias to adjacent_difference_inplace * fix review comments have separate code paths for non aliased and in place calls in tests documentation updates * Fix unique_by_key to allow input and output values iterators aliasing * fix rocm 6.0 compilation errors add api variant logging to other test * Help compiler optimize unused value_type * refactor(detail/various.hpp): Perfect forwarding for foreach_in_tuple Instead of taking l-value references use std::forward to forward value type to the passed function. For example allows foreach_in_tuple to be called on const tuples. * build(CMakeLists.txt): Skip packaging when project isn't toplevel Skip packaging when we're being added as a sub-project (for example using FetchContent). Only a single project can use `rocm_create_package()` we don't want to trump over whoever is depending on us. This should fix the warnings like "rocm_package_add_deb_dependencies called after rocm_create_package!" in hipCUB (and probably rocThrust). * CHANGELOG and copyright updates * refactor(detail/temp_storage.hpp): Make layout() const on partitions * build(cmake): Add cmake option to disable installation Default to ON for backward compatibility * refactor: Further improve foreach_in_tuple - Use an array instead of an (implicit) initializer_list. 
- Be consistent with the template parameter name * build(CHANGELOG.md, CMakeLists.txt): Set version number in CMake * docs: Fix some documentation warnings/errors * docs: Fixed changelog style * refactor(test_device_adjacent_difference): Reduce code duplication Simplify code by factoring out common parts of in-place and out-of-place tests. * docs: Fixed SPHINX_DIR * refactor(warp_scan): Deprecate undocumented to_exclusive APIs These were used by prior versions of CUB, but now have public replacements. * Add initial Windows CI * Substituted DeviceSelectWarpSize with device_test_enabled_for_warp_size_v * Documentation fix after rebase --------- Co-authored-by: Bence Parajdi Co-authored-by: Nara Prasetya Co-authored-by: Gergely Meszaros Co-authored-by: Balint Soproni Co-authored-by: Beatriz Navidad Vilches Co-authored-by: Nick Breed --- .gitlab-ci.yml | 99 +- CHANGELOG.md | 16 + CMakeLists.txt | 52 +- benchmark/CMakeLists.txt | 9 +- benchmark/ConfigAutotuneSettings.cmake | 12 + .../benchmark_block_run_length_decode.cpp | 3 +- .../benchmark_device_histogram.parallel.hpp | 6 +- .../benchmark_device_segmented_radix_sort.cpp | 521 -- ...hmark_device_segmented_radix_sort_keys.cpp | 322 ++ ..._segmented_radix_sort_keys.parallel.cpp.in | 34 + ...ice_segmented_radix_sort_keys.parallel.hpp | 373 ++ ...mark_device_segmented_radix_sort_pairs.cpp | 357 ++ ...segmented_radix_sort_pairs.parallel.cpp.in | 35 + ...ce_segmented_radix_sort_pairs.parallel.hpp | 412 ++ benchmark/benchmark_utils.hpp | 10 +- benchmark/benchmark_warp_exchange.cpp | 99 +- benchmark/benchmark_warp_scan.cpp | 11 +- docs/conf.py | 41 +- docs/device_ops/adjacent_difference.rst | 10 + docs/device_ops/index.rst | 1 + docs/device_ops/memcpy.rst | 19 + docs/device_ops/partition.rst | 5 + docs/reference/intrinsics.rst | 2 + docs/reference/ops_summary.rst | 5 +- docs/sphinx/_toc.yml.in | 4 +- rocprim/CMakeLists.txt | 59 +- .../rocprim/detail/match_result_type.hpp | 111 - .../include/rocprim/detail/temp_storage.hpp | 
12 +- rocprim/include/rocprim/detail/various.hpp | 22 +- .../config/device_adjacent_difference.hpp | 6 +- .../device_adjacent_difference_inplace.hpp | 4 +- .../device/detail/config/device_histogram.hpp | 4 +- .../device/detail/config/device_reduce.hpp | 4 +- .../device/detail/config/device_scan.hpp | 4 +- .../detail/config/device_scan_by_key.hpp | 4 +- .../config/device_segmented_radix_sort.hpp | 4886 +++++++++++++++++ .../device/detail/device_config_helper.hpp | 237 +- .../device/detail/device_partition.hpp | 216 +- .../device/detail/device_reduce_by_key.hpp | 5 +- .../rocprim/device/detail/device_scan.hpp | 30 +- .../device/detail/device_scan_by_key.hpp | 1 + .../detail/device_segmented_radix_sort.hpp | 152 +- .../device/detail/device_transform.hpp | 5 +- .../device/device_adjacent_difference.hpp | 214 +- .../device_adjacent_difference_config.hpp | 5 + .../include/rocprim/device/device_reduce.hpp | 8 +- .../rocprim/device/device_reduce_by_key.hpp | 3 +- .../include/rocprim/device/device_scan.hpp | 306 +- .../rocprim/device/device_scan_by_key.hpp | 131 +- .../device/device_segmented_radix_sort.hpp | 488 +- .../device_segmented_radix_sort_config.hpp | 329 +- .../device/device_segmented_reduce.hpp | 11 +- .../rocprim/device/device_segmented_scan.hpp | 3 +- .../rocprim/device/device_transform.hpp | 7 +- .../device/device_transform_config.hpp | 4 +- rocprim/include/rocprim/intrinsics/warp.hpp | 6 +- .../rocprim/iterator/transform_iterator.hpp | 17 +- .../rocprim/thread/thread_operators.hpp | 4 +- rocprim/include/rocprim/type_traits.hpp | 118 +- .../rocprim/warp/detail/warp_scan_dpp.hpp | 22 +- .../warp/detail/warp_scan_shared_mem.hpp | 15 +- .../rocprim/warp/detail/warp_scan_shuffle.hpp | 17 +- .../include/rocprim/warp/warp_exchange.hpp | 14 +- rocprim/include/rocprim/warp/warp_scan.hpp | 89 +- scripts/autotune/create_optimization.py | 12 + .../adjacent_difference_config_template | 2 +- ...djacent_difference_inplace_config_template | 2 +- 
.../templates/histogram_config_template | 2 +- .../autotune/templates/reduce_config_template | 2 +- .../autotune/templates/scan_config_template | 2 +- .../templates/scanbykey_config_template | 2 +- .../segmented_radix_sort_config_template | 20 + test/CMakeLists.txt | 20 +- test/common_test_header.hpp | 18 +- test/hipgraph/test_hipgraph_algs.cpp | 4 +- test/rocprim/CMakeLists.txt | 9 +- test/rocprim/indirect_iterator.hpp | 179 + .../test_device_adjacent_difference.cpp | 463 +- test/rocprim/test_device_radix_sort.cpp.in | 2 +- test/rocprim/test_device_scan.cpp | 179 +- .../test_device_segmented_radix_sort.hpp | 79 +- test/rocprim/test_device_select.cpp | 200 +- test/rocprim/test_invoke_result.cpp | 123 + test/rocprim/test_utils.hpp | 11 +- test/rocprim/test_utils_assertions.hpp | 3 +- test/rocprim/test_utils_types.hpp | 8 +- test/rocprim/test_warp_exchange.cpp | 252 +- test/rocprim/test_warp_load.cpp | 139 +- test/rocprim/test_warp_scan.hpp | 312 ++ test/rocprim/test_warp_scan.kernels.hpp | 45 +- test/rocprim/test_warp_store.cpp | 134 +- 91 files changed, 9911 insertions(+), 2344 deletions(-) delete mode 100644 benchmark/benchmark_device_segmented_radix_sort.cpp create mode 100644 benchmark/benchmark_device_segmented_radix_sort_keys.cpp create mode 100644 benchmark/benchmark_device_segmented_radix_sort_keys.parallel.cpp.in create mode 100644 benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp create mode 100644 benchmark/benchmark_device_segmented_radix_sort_pairs.cpp create mode 100644 benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.cpp.in create mode 100644 benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp create mode 100644 docs/device_ops/memcpy.rst delete mode 100644 rocprim/include/rocprim/detail/match_result_type.hpp create mode 100644 rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp create mode 100644 scripts/autotune/templates/segmented_radix_sort_config_template create mode 100644 
test/rocprim/indirect_iterator.hpp create mode 100644 test/rocprim/test_invoke_result.cpp diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 66c0092e6..3bf7882b8 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,6 +29,7 @@ include: - /deps-docs.yaml - /deps-rocm.yaml - /deps-vcpkg.yaml + - /deps-windows.yaml - /gpus-rocm.yaml - /rules.yaml @@ -247,6 +248,7 @@ build:benchmark: -S $CI_PROJECT_DIR -G Ninja -D CMAKE_CXX_COMPILER="$AMDCLANG" + -D CMAKE_CXX_FLAGS="-Wall -Wextra -Werror -Wno-#pragma-messages" -D CMAKE_BUILD_TYPE=Release -D BUILD_TEST=OFF -D BUILD_EXAMPLE=OFF @@ -260,6 +262,48 @@ build:benchmark: - $BUILD_DIR/deps/googlebenchmark/ expire_in: 2 weeks +build:windows: + stage: build + needs: [] + extends: + - .rules:build + - .gpus:rocm-windows + - .deps:rocm-windows + - .deps:visual-studio-devshell + parallel: + matrix: + - BUILD_TYPE: + # Disabled due to extensive link times. 
+ # This is tracked in issue 679 + #- Debug + - Release + BUILD_TARGET: + - BENCHMARK + - TEST + script: + - mkdir -p $CI_PROJECT_DIR/build + - cmake -G Ninja + -S $CI_PROJECT_DIR + -B $CI_PROJECT_DIR/build + -D BUILD_$BUILD_TARGET=ON + -D GPU_TARGETS=$GPU_TARGET + -D CMAKE_CXX_COMPILER:PATH="${env:HIP_PATH}\bin\clang++.exe" + -D CMAKE_C_COMPILER:PATH="${env:HIP_PATH}\bin\clang.exe" + -D CMAKE_PREFIX_PATH:PATH="${env:HIP_PATH}" + -D CMAKE_BUILD_TYPE="$BUILD_TYPE" + - cmake --build "$CI_PROJECT_DIR/build" + artifacts: + paths: + - $CI_PROJECT_DIR/build/test/test_* + - $CI_PROJECT_DIR/build/test/rocprim/test_* + - $CI_PROJECT_DIR/build/test/CTestTestfile.cmake + - $CI_PROJECT_DIR/build/test/rocprim/CTestTestfile.cmake + - $CI_PROJECT_DIR/build/gtest/ + - $CI_PROJECT_DIR/build/CMakeCache.txt + - $CI_PROJECT_DIR/build/.ninja_log + - $CI_PROJECT_DIR/build/CTestTestfile.cmake + expire_in: 2 weeks + autotune:build: stage: autotune needs: [] @@ -289,6 +333,19 @@ autotune:build: -D GPU_TARGETS=$GPU_TARGETS - cmake --build . 
--target $BENCHMARK_TARGETS - 'rm -rf $BUILD_DIR/benchmark/benchmark*.parallel' + # remove benchmark executables if their size together is too large for gitlab ci to handle + - | + total_size_bytes=0 + while read -r file_size; do + total_size_bytes=$((total_size_bytes + file_size)) + done < <(stat --format="%s" benchmark/benchmark*) + total_size_gib="$(numfmt --round=down --to-unit=Gi "$total_size_bytes")" + if [ "$total_size_gib" -ge 3 ]; then + printf "Total size: %s (%d bytes) > 3GiB, skipping benchmark executables from the artifact.\n" \ + "$(numfmt --to=iec-i "$total_size_bytes")" "$total_size_bytes" + rm benchmark/benchmark* + fi + artifacts: paths: - $BUILD_DIR/benchmark/benchmark* @@ -320,6 +377,39 @@ test: --resource-spec-file ./resources.json --parallel $PARALLEL_JOBS +.test-windows-base: + stage: test + extends: + - .deps:rocm-windows + - .gpus:rocm-gpus-windows + - .deps:visual-studio-devshell + - .rules:test + script: + - cd $CI_PROJECT_DIR/build + - ctest --output-on-failure + +# Disabled due to extensive link times. +# This is tracked in issue 679 +# test-windows-debug: +# extends: +# - .test-windows-base +# needs: +# - job: build:windows +# parallel: +# matrix: +# - BUILD_TYPE: Debug +# BUILD_TARGET: TEST + +test-windows-release: + extends: + - .test-windows-base + needs: + - job: build:windows + parallel: + matrix: + - BUILD_TYPE: Release + BUILD_TARGET: TEST + .test-package: script: - cmake @@ -369,6 +459,8 @@ test:deb: test:docs: stage: test + variables: + SPHINX_DIR: $DOCS_DIR/sphinx extends: - .rules:test - .build:docs @@ -472,6 +564,11 @@ autotune:execute-tuning: # On ROCm 5.7 or later, check if this can be removed - the presumption is that the failure is caused by a compiler issue. - > cd "${CI_PROJECT_DIR}" + - | + if [ ! -d "${BUILD_DIR}/benchmark" ]; then + echo "There are no benchmark executables. Run the build job with a BUILD_TARGET." 
+ exit 1 + fi - mkdir -p "${AUTOTUNE_RESULT_DIR}" - python3 .gitlab/run_benchmarks.py diff --git a/CHANGELOG.md b/CHANGELOG.md index baefd45a2..09a057021 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,22 @@ Documentation for rocPRIM is available at [https://rocm.docs.amd.com/projects/rocPRIM/en/latest/](https://rocm.docs.amd.com/projects/rocPRIM/en/latest/). +## Unreleased rocPRIM-3.2.0 for ROCm 6.2.0 + +### Additions + +* New overloads for `warp_scan::exclusive_scan` that take no initial value. These new overloads will write an unspecified result to the first value of each warp. +* The internal accumulator type of `inclusive_scan(_by_key)` and `exclusive_scan(_by_key)` is now exposed as an optional type parameter. + * The default accumulator type is still the value type of the input iterator (inclusive scan) or the initial value's type (exclusive scan). + This is the same behaviour as before this change. +* New overload for `device_adjacent_difference_inplace` that allows separate input and output iterators, but allows them to point to the same element. + +### Fixes + +* Fixed incorrect results of `warp_exchange::blocked_to_striped_shuffle` and `warp_exchange::striped_to_blocked_shuffle` when the block size is + larger than the logical warp size. The test suite has been updated with such cases. +* Fixed incorrect results returned when calling device `unique_by_key` with overlapping `values_input` and `values_output`. + ## Unreleased rocPRIM-3.1.0 for ROCm 6.1.0 ### Additions diff --git a/CMakeLists.txt b/CMakeLists.txt index b2725db26..b92c8ab8c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,6 +29,12 @@ set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "Install path prefix, prepended # rocPRIM project project(rocprim LANGUAGES CXX) +if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) + set(ROCPRIM_PROJECT_IS_TOP_LEVEL TRUE) +else() + set(ROCPRIM_PROJECT_IS_TOP_LEVEL FALSE) +endif() + #Adding CMAKE_PREFIX_PATH if(WIN32) set(ROCM_ROOT "$ENV{HIP_PATH}" CACHE PATH "Root directory of the ROCm installation") @@ -44,6 +50,7 @@ option(USE_HIP_CPU "Prefer HIP-CPU runtime instead of HW acceleration" OFF) # Disables building tests, benchmarks, examples option(ONLY_INSTALL "Only install" OFF) option(BUILD_CODE_COVERAGE "Build with code coverage enabled" OFF) +option(ROCPRIM_INSTALL "Enable installation of rocPRIM (projects embedding rocPRIM may want to turn this OFF)" ON) # CMake modules list(APPEND CMAKE_MODULE_PATH @@ -94,7 +101,7 @@ endif() # FOR HANDLING ENABLE/DISABLE OPTIONAL BACKWARD COMPATIBILITY for FILE/FOLDER REORG option(BUILD_FILE_REORG_BACKWARD_COMPATIBILITY "Build with file/folder reorg with backward compatibility enabled" OFF) -if(BUILD_FILE_REORG_BACKWARD_COMPATIBILITY AND NOT WIN32) +if(ROCPRIM_INSTALL AND BUILD_FILE_REORG_BACKWARD_COMPATIBILITY AND NOT WIN32) rocm_wrap_header_dir( "${PROJECT_SOURCE_DIR}/rocprim/include/rocprim" WRAPPER_LOCATIONS rocprim/include/rocprim @@ -114,7 +121,7 @@ if(USE_HIP_CPU) endif() # Setup VERSION -set(VERSION_STRING "3.1.0") +set(VERSION_STRING "3.2.0") rocm_setup_version(VERSION ${VERSION_STRING}) # Print configuration summary @@ -124,20 +131,24 @@ print_configuration_summary() # rocPRIM library add_subdirectory(rocprim) -if(NOT ONLY_INSTALL AND (BUILD_TEST OR BUILD_BENCHMARK)) +if(ROCPRIM_PROJECT_IS_TOP_LEVEL AND NOT ONLY_INSTALL AND (BUILD_TEST OR BUILD_BENCHMARK)) rocm_package_setup_component(clients) endif() # Tests if(BUILD_TEST AND NOT 
ONLY_INSTALL) - rocm_package_setup_client_component(tests) + if (ROCPRIM_PROJECT_IS_TOP_LEVEL) + rocm_package_setup_client_component(tests) + endif() enable_testing() add_subdirectory(test) endif() # Benchmarks if(BUILD_BENCHMARK AND NOT ONLY_INSTALL) - rocm_package_setup_client_component(benchmarks) + if (ROCPRIM_PROJECT_IS_TOP_LEVEL) + rocm_package_setup_client_component(benchmarks) + endif() add_subdirectory(benchmark) endif() @@ -147,20 +158,21 @@ if(BUILD_EXAMPLE AND NOT ONLY_INSTALL) endif() # Package -set(BUILD_SHARED_LIBS ON) # Build as though shared library for naming -rocm_package_add_dependencies(DEPENDS "hip-rocclr >= 3.5.0") -set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE.txt") -set(CPACK_RPM_PACKAGE_LICENSE "MIT") - -set(CPACK_RPM_EXCLUDE_FROM_AUTO_FILELIST_ADDITION "\${CPACK_PACKAGING_INSTALL_PREFIX}" ) - -rocm_create_package( - NAME rocprim - DESCRIPTION "Radeon Open Compute Parallel Primitives Library" - MAINTAINER "rocPRIM Maintainer " - HEADER_ONLY -) - +if (ROCPRIM_PROJECT_IS_TOP_LEVEL) + set(BUILD_SHARED_LIBS ON) # Build as though shared library for naming + rocm_package_add_dependencies(DEPENDS "hip-rocclr >= 3.5.0") + set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE.txt") + set(CPACK_RPM_PACKAGE_LICENSE "MIT") + + set(CPACK_RPM_EXCLUDE_FROM_AUTO_FILELIST_ADDITION "\${CPACK_PACKAGING_INSTALL_PREFIX}" ) + + rocm_create_package( + NAME rocprim + DESCRIPTION "Radeon Open Compute Parallel Primitives Library" + MAINTAINER "rocPRIM Maintainer " + HEADER_ONLY + ) +endif() # # ADDITIONAL TARGETS FOR CODE COVERAGE diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index 499291a39..5bde961ed 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -95,7 +95,9 @@ function(add_rocprim_benchmark BENCHMARK_SOURCE) RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/benchmark" ) - rocm_install(TARGETS ${BENCHMARK_TARGET} COMPONENT benchmarks) + if (ROCPRIM_INSTALL) + rocm_install(TARGETS ${BENCHMARK_TARGET} COMPONENT benchmarks) + endif() if (WIN32 AND NOT DEFINED DLLS_COPIED) set(DLLS_COPIED "YES") set(DLLS_COPIED ${DLLS_COPIED} PARENT_SCOPE) @@ -145,7 +147,8 @@ add_rocprim_benchmark(benchmark_device_run_length_encode.cpp) add_rocprim_benchmark(benchmark_device_scan.cpp) add_rocprim_benchmark(benchmark_device_scan_by_key.cpp) add_rocprim_benchmark(benchmark_device_select.cpp) -add_rocprim_benchmark(benchmark_device_segmented_radix_sort.cpp) +add_rocprim_benchmark(benchmark_device_segmented_radix_sort_keys.cpp) +add_rocprim_benchmark(benchmark_device_segmented_radix_sort_pairs.cpp) add_rocprim_benchmark(benchmark_device_segmented_reduce.cpp) add_rocprim_benchmark(benchmark_device_transform.cpp) add_rocprim_benchmark(benchmark_warp_exchange.cpp) diff --git a/benchmark/ConfigAutotuneSettings.cmake b/benchmark/ConfigAutotuneSettings.cmake index 510c222ad..85556d894 100644 --- a/benchmark/ConfigAutotuneSettings.cmake +++ b/benchmark/ConfigAutotuneSettings.cmake @@ -81,5 +81,17 @@ ${TUNING_TYPES};${LIMITED_TUNING_TYPES};using_warp_scan reduce_then_scan" PARENT set(list_across "\ binary_search upper_bound lower_bound;${TUNING_TYPES};${LIMITED_TUNING_TYPES};64 128 256;1 2 4 8 16" PARENT_SCOPE) set(output_pattern_suffix "@SubAlgorithm@_@ValueType@_@OutputType@_@BlockSize@_@ItemsPerThread@" PARENT_SCOPE) + elseif(file STREQUAL "benchmark_device_segmented_radix_sort_keys") + set(list_across_names "\ +KeyType;BlockSize;ItemsPerThread;PartitionAllowed" PARENT_SCOPE) + set(list_across "${TUNING_TYPES};128 256;4 8 16;false" PARENT_SCOPE) + set(output_pattern_suffix "\ 
+@KeyType@_@BlockSize@_@ItemsPerThread@_@PartitionAllowed@" PARENT_SCOPE) + elseif(file STREQUAL "benchmark_device_segmented_radix_sort_pairs") + set(list_across_names "\ +KeyType;ValueType;BlockSize;ItemsPerThread;PartitionAllowed" PARENT_SCOPE) + set(list_across "${TUNING_TYPES};int8_t;64;4 8 16;true false" PARENT_SCOPE) + set(output_pattern_suffix "\ +@KeyType@_@ValueType@_@BlockSize@_@ItemsPerThread@_@PartitionAllowed@" PARENT_SCOPE) endif() endfunction() diff --git a/benchmark/benchmark_block_run_length_decode.cpp b/benchmark/benchmark_block_run_length_decode.cpp index 04e1f0428..34adfdc14 100644 --- a/benchmark/benchmark_block_run_length_decode.cpp +++ b/benchmark/benchmark_block_run_length_decode.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2021-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -57,7 +57,6 @@ __global__ rocprim::block_load_direct_blocked(global_thread_idx, d_run_items, run_items); rocprim::block_load_direct_blocked(global_thread_idx, d_run_offsets, run_offsets); - ROCPRIM_SHARED_MEMORY typename BlockRunLengthDecodeT::storage_type temp_storage; BlockRunLengthDecodeT block_run_length_decode(run_items, run_offsets); const OffsetT total_decoded_size diff --git a/benchmark/benchmark_device_histogram.parallel.hpp b/benchmark/benchmark_device_histogram.parallel.hpp index 6137ba7ae..146c1cc40 100644 --- a/benchmark/benchmark_device_histogram.parallel.hpp +++ b/benchmark/benchmark_device_histogram.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -358,8 +358,8 @@ struct device_histogram_benchmark_generator template - auto create(std::vector>& storage, - const std::vector& cases) -> + auto create(std::vector>& /*storage*/, + const std::vector& /*cases*/) -> typename std::enable_if::type {} diff --git a/benchmark/benchmark_device_segmented_radix_sort.cpp b/benchmark/benchmark_device_segmented_radix_sort.cpp deleted file mode 100644 index e471f5201..000000000 --- a/benchmark/benchmark_device_segmented_radix_sort.cpp +++ /dev/null @@ -1,521 +0,0 @@ -// MIT License -// -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. 
- -#include -#include -#include -#include -#include - -// Google Benchmark -#include "benchmark/benchmark.h" -// CmdParser -#include "cmdparser.hpp" -#include "benchmark_utils.hpp" - -// HIP API -#include - -// rocPRIM -#include - -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; -#endif - -namespace rp = rocprim; - -namespace -{ - -constexpr unsigned int warmup_size = 2; -constexpr size_t min_size = 30000; -constexpr std::array segment_counts{ 10, 100, 1000, 2500, 5000, 7500, 10000, 100000 }; -constexpr std::array segment_lengths{30, 256, 3000, 300000}; -} - - -template -void run_sort_keys_benchmark(benchmark::State& state, - size_t num_segments, - size_t mean_segment_length, - size_t target_size, - hipStream_t stream) -{ - using offset_type = int; - using key_type = Key; - - std::vector offsets; - offsets.push_back(0); - - static constexpr int seed = 716; - std::default_random_engine gen(seed); - - std::normal_distribution segment_length_dis(static_cast(mean_segment_length), - 0.1 * mean_segment_length); - - size_t offset = 0; - for(size_t segment_index = 0; segment_index < num_segments;) - { - const double segment_length_candidate = std::round(segment_length_dis(gen)); - if (segment_length_candidate < 0) - { - continue; - } - const offset_type segment_length = static_cast(segment_length_candidate); - offset += segment_length; - offsets.push_back(offset); - ++segment_index; - } - const size_t size = offset; - const size_t segments_count = offsets.size() - 1; - - std::vector keys_input; - if(std::is_floating_point::value) - { - keys_input = get_random_data( - size, - static_cast(-1000), - static_cast(1000) - ); - } - else - { - keys_input = get_random_data( - size, - std::numeric_limits::min(), - std::numeric_limits::max() - ); - } - size_t batch_size = 1; - if(size < target_size) - { - batch_size = (target_size + size - 1) / size; - } - - offset_type * d_offsets; - HIP_CHECK(hipMalloc(&d_offsets, offsets.size() * sizeof(offset_type))); - HIP_CHECK( - 
hipMemcpy( - d_offsets, offsets.data(), - offsets.size() * sizeof(offset_type), - hipMemcpyHostToDevice - ) - ); - - key_type * d_keys_input; - key_type * d_keys_output; - HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); - HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); - HIP_CHECK( - hipMemcpy( - d_keys_input, keys_input.data(), - size * sizeof(key_type), - hipMemcpyHostToDevice - ) - ); - - void * d_temporary_storage = nullptr; - size_t temporary_storage_bytes = 0; - HIP_CHECK( - rp::segmented_radix_sort_keys( - d_temporary_storage, temporary_storage_bytes, - d_keys_input, d_keys_output, size, - segments_count, d_offsets, d_offsets + 1, - 0, sizeof(key_type) * 8, - stream, false - ) - ); - - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - HIP_CHECK(hipDeviceSynchronize()); - - // Warm-up - for(size_t i = 0; i < warmup_size; i++) - { - HIP_CHECK( - rp::segmented_radix_sort_keys( - d_temporary_storage, temporary_storage_bytes, - d_keys_input, d_keys_output, size, - segments_count, d_offsets, d_offsets + 1, - 0, sizeof(key_type) * 8, - stream, false - ) - ); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for (auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; i++) - { - HIP_CHECK( - rp::segmented_radix_sort_keys( - d_temporary_storage, temporary_storage_bytes, - d_keys_input, d_keys_output, size, - segments_count, d_offsets, d_offsets + 1, - 0, sizeof(key_type) * 8, - stream, false - ) - ); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - 
HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_offsets)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys_output)); -} - -template -void run_sort_pairs_benchmark(benchmark::State& state, - size_t num_segments, - size_t mean_segment_length, - size_t target_size, - hipStream_t stream) -{ - using offset_type = int; - using key_type = Key; - using value_type = Value; - - // Generate data - std::vector offsets; - offsets.push_back(0); - - static constexpr int seed = 716; - std::default_random_engine gen(seed); - - std::normal_distribution segment_length_dis(static_cast(mean_segment_length), - 0.1 * mean_segment_length); - - size_t offset = 0; - for(size_t segment_index = 0; segment_index < num_segments;) - { - const double segment_length_candidate = std::round(segment_length_dis(gen)); - if (segment_length_candidate < 0) - { - continue; - } - const offset_type segment_length = static_cast(segment_length_candidate); - offset += segment_length; - offsets.push_back(offset); - ++segment_index; - } - const size_t size = offset; - const size_t segments_count = offsets.size() - 1; - - std::vector keys_input; - if(std::is_floating_point::value) - { - keys_input = get_random_data( - size, - static_cast(-1000), - static_cast(1000) - ); - } - else - { - keys_input = get_random_data( - size, - std::numeric_limits::min(), - std::numeric_limits::max() - ); - } - size_t batch_size = 1; - if(size < target_size) - { - batch_size = (target_size + size - 1) / size; - } - - std::vector values_input(size); - std::iota(values_input.begin(), values_input.end(), 0); - - offset_type * d_offsets; - HIP_CHECK(hipMalloc(&d_offsets, (segments_count + 1) * sizeof(offset_type))); - HIP_CHECK( - hipMemcpy( - d_offsets, offsets.data(), - 
(segments_count + 1) * sizeof(offset_type), - hipMemcpyHostToDevice - ) - ); - - key_type * d_keys_input; - key_type * d_keys_output; - HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); - HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); - HIP_CHECK( - hipMemcpy( - d_keys_input, keys_input.data(), - size * sizeof(key_type), - hipMemcpyHostToDevice - ) - ); - - value_type * d_values_input; - value_type * d_values_output; - HIP_CHECK(hipMalloc(&d_values_input, size * sizeof(value_type))); - HIP_CHECK(hipMalloc(&d_values_output, size * sizeof(value_type))); - HIP_CHECK( - hipMemcpy( - d_values_input, values_input.data(), - size * sizeof(value_type), - hipMemcpyHostToDevice - ) - ); - - void * d_temporary_storage = nullptr; - size_t temporary_storage_bytes = 0; - HIP_CHECK( - rp::segmented_radix_sort_pairs( - d_temporary_storage, temporary_storage_bytes, - d_keys_input, d_keys_output, d_values_input, d_values_output, size, - segments_count, d_offsets, d_offsets + 1, - 0, sizeof(key_type) * 8, - stream, false - ) - ); - - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - HIP_CHECK(hipDeviceSynchronize()); - - // Warm-up - for(size_t i = 0; i < warmup_size; i++) - { - HIP_CHECK( - rp::segmented_radix_sort_pairs( - d_temporary_storage, temporary_storage_bytes, - d_keys_input, d_keys_output, d_values_input, d_values_output, size, - segments_count, d_offsets, d_offsets + 1, - 0, sizeof(key_type) * 8, - stream, false - ) - ); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for (auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; i++) - { - HIP_CHECK( - rp::segmented_radix_sort_pairs( - d_temporary_storage, temporary_storage_bytes, - d_keys_input, d_keys_output, d_values_input, d_values_output, size, - segments_count, d_offsets, 
d_offsets + 1, - 0, sizeof(key_type) * 8, - stream, false - ) - ); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed( - state.iterations() * batch_size * size * (sizeof(key_type) + sizeof(value_type)) - ); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_offsets)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys_output)); - HIP_CHECK(hipFree(d_values_input)); - HIP_CHECK(hipFree(d_values_output)); -} - -template -void add_sort_keys_benchmarks(std::vector &benchmarks, - hipStream_t stream, - size_t max_size, - size_t min_size, - size_t target_size) -{ - std::string key_name = Traits::name(); - std::string value_name = Traits::name(); - for(const auto segment_count : segment_counts) - { - for(const auto segment_length : segment_lengths) - { - const auto number_of_elements = segment_count * segment_length; - if(number_of_elements > max_size || number_of_elements < min_size) - { - continue; - } - benchmarks.push_back(benchmark::RegisterBenchmark( - bench_naming::format_name( - "{lvl:device,algo:radix_sort_segmented,key_type:" + key_name + ",value_type:" - + value_name + ",segment_count:" + std::to_string(segment_count) - + ",segment_length:" + std::to_string(segment_length) + ",cfg:default_config}") - .c_str(), - [=](benchmark::State& state) { - run_sort_keys_benchmark(state, - segment_count, - segment_length, - target_size, - stream); - })); - } - } -} - -template -void add_sort_pairs_benchmarks(std::vector &benchmarks, - hipStream_t stream, - size_t max_size, - size_t min_size, - size_t target_size) -{ - 
std::string key_name = Traits::name(); - std::string value_name = Traits::name(); - for(const auto segment_count : segment_counts) - { - for(const auto segment_length : segment_lengths) - { - const auto number_of_elements = segment_count * segment_length; - if(number_of_elements > max_size || number_of_elements < min_size) - { - continue; - } - benchmarks.push_back(benchmark::RegisterBenchmark( - bench_naming::format_name( - "{lvl:device,algo:radix_sort_segmented,key_type:" + key_name + ",value_type:" - + value_name + ",segment_count:" + std::to_string(segment_count) - + ",segment_length:" + std::to_string(segment_length) + ",cfg:default_config}") - .c_str(), - [=](benchmark::State& state) - { - run_sort_pairs_benchmark(state, - segment_count, - segment_length, - target_size, - stream); - })); - } - } -} - -int main(int argc, char *argv[]) -{ - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); - - // Add benchmarks - std::vector benchmarks; - add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_keys_benchmarks(benchmarks, stream, size, min_size, 
size / 2); - - using custom_float2 = custom_type; - using custom_double2 = custom_type; - add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); - add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; -} diff --git a/benchmark/benchmark_device_segmented_radix_sort_keys.cpp b/benchmark/benchmark_device_segmented_radix_sort_keys.cpp new file mode 100644 index 000000000..dfd7e14e9 --- /dev/null +++ b/benchmark/benchmark_device_segmented_radix_sort_keys.cpp @@ -0,0 +1,322 @@ +// MIT License +// +// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include +#include +#include +#include +#include + +// Google Benchmark +#include "benchmark/benchmark.h" +// CmdParser +#include "benchmark_utils.hpp" +#include "cmdparser.hpp" + +// HIP API +#include + +// rocPRIM +#include + +#ifndef DEFAULT_N +const size_t DEFAULT_N = 1024 * 1024 * 32; +#endif + +namespace rp = rocprim; + +namespace +{ + +constexpr unsigned int warmup_size = 2; +constexpr size_t min_size = 30000; +constexpr std::array segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; +constexpr std::array segment_lengths{30, 256, 3000, 300000}; +} // namespace + +// This benchmark only handles the rocprim::segmented_radix_sort_keys function. The benchmark was separated into two (keys and pairs), +// because the binary became too large to link. Runs into a "relocation R_X86_64_PC32 out of range" error. +// This happens partially, because of the algorithm has 4 kernels, and decides at runtime which one to call. 
+ +template +void run_sort_keys_benchmark(benchmark::State& state, + size_t num_segments, + size_t mean_segment_length, + size_t target_size, + hipStream_t stream) +{ + using offset_type = int; + using key_type = Key; + + std::vector offsets; + offsets.push_back(0); + + static constexpr int seed = 716; + std::default_random_engine gen(seed); + + std::normal_distribution segment_length_dis(static_cast(mean_segment_length), + 0.1 * mean_segment_length); + + size_t offset = 0; + for(size_t segment_index = 0; segment_index < num_segments;) + { + const double segment_length_candidate = std::round(segment_length_dis(gen)); + if(segment_length_candidate < 0) + { + continue; + } + const offset_type segment_length = static_cast(segment_length_candidate); + offset += segment_length; + offsets.push_back(offset); + ++segment_index; + } + const size_t size = offset; + const size_t segments_count = offsets.size() - 1; + + std::vector keys_input; + if(std::is_floating_point::value) + { + keys_input = get_random_data(size, + static_cast(-1000), + static_cast(1000)); + } + else + { + keys_input = get_random_data(size, + std::numeric_limits::min(), + std::numeric_limits::max()); + } + size_t batch_size = 1; + if(size < target_size) + { + batch_size = (target_size + size - 1) / size; + } + + offset_type* d_offsets; + HIP_CHECK(hipMalloc(&d_offsets, offsets.size() * sizeof(offset_type))); + HIP_CHECK(hipMemcpy(d_offsets, + offsets.data(), + offsets.size() * sizeof(offset_type), + hipMemcpyHostToDevice)); + + key_type* d_keys_input; + key_type* d_keys_output; + HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); + HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); + HIP_CHECK( + hipMemcpy(d_keys_input, keys_input.data(), size * sizeof(key_type), hipMemcpyHostToDevice)); + + void* d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; + HIP_CHECK(rp::segmented_radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + 
d_keys_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + HIP_CHECK(hipDeviceSynchronize()); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + HIP_CHECK(rp::segmented_radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + } + HIP_CHECK(hipDeviceSynchronize()); + + // HIP events creation + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for(auto _ : state) + { + // Record start event + HIP_CHECK(hipEventRecord(start, stream)); + + for(size_t i = 0; i < batch_size; i++) + { + HIP_CHECK(rp::segmented_radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + } + + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_mseconds; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + + // Destroy HIP events + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_offsets)); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_keys_output)); +} + +template +void add_sort_keys_benchmarks(std::vector& benchmarks, + hipStream_t stream, + size_t max_size, + size_t min_size, + size_t target_size) +{ + std::string key_name = Traits::name(); + std::string value_name = Traits::name(); + 
for(const auto segment_count : segment_counts) + { + for(const auto segment_length : segment_lengths) + { + const auto number_of_elements = segment_count * segment_length; + if(number_of_elements > max_size || number_of_elements < min_size) + { + continue; + } + benchmarks.push_back(benchmark::RegisterBenchmark( + bench_naming::format_name( + "{lvl:device,algo:radix_sort_segmented,key_type:" + key_name + ",value_type:" + + value_name + ",segment_count:" + std::to_string(segment_count) + + ",segment_length:" + std::to_string(segment_length) + ",cfg:default_config}") + .c_str(), + [=](benchmark::State& state) { + run_sort_keys_benchmark(state, + segment_count, + segment_length, + target_size, + stream); + })); + } + } +} + +int main(int argc, char* argv[]) +{ + cli::Parser parser(argc, argv); + parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("trials", "trials", -1, "number of iterations"); + parser.set_optional("name_format", + "name_format", + "human", + "either: json,human,txt"); + +#ifdef BENCHMARK_CONFIG_TUNING + // optionally run an evenly split subset of benchmarks, when making multiple program invocations + parser.set_optional("parallel_instance", + "parallel_instance", + 0, + "parallel instance index"); + parser.set_optional("parallel_instances", + "parallel_instances", + 1, + "total parallel instances"); +#endif + + parser.run_and_exit_if_error(); + + // Parse argv + benchmark::Initialize(&argc, argv); + const size_t size = parser.get("size"); + const int trials = parser.get("trials"); + bench_naming::set_format(parser.get("name_format")); + + // HIP + hipStream_t stream = 0; // default + + // Benchmark info + add_common_benchmark_info(); + benchmark::AddCustomContext("size", std::to_string(size)); + + // Add benchmarks + std::vector benchmarks; +#ifdef BENCHMARK_CONFIG_TUNING + const int parallel_instance = parser.get("parallel_instance"); + const int parallel_instances = parser.get("parallel_instances"); + 
config_autotune_register::register_benchmark_subset(benchmarks, + parallel_instance, + parallel_instances, + size, + stream); +#else + add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_keys_benchmarks(benchmarks, stream, size, min_size, size / 2); +#endif + + // Use manual timing + for(auto& b : benchmarks) + { + b->UseManualTime(); + b->Unit(benchmark::kMillisecond); + } + + // Force number of iterations + if(trials > 0) + { + for(auto& b : benchmarks) + { + b->Iterations(trials); + } + } + + // Run benchmarks + benchmark::RunSpecifiedBenchmarks(); + return 0; +} diff --git a/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.cpp.in b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.cpp.in new file mode 100644 index 000000000..4913fdff7 --- /dev/null +++ b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.cpp.in @@ -0,0 +1,34 @@ +// MIT License +// +// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include + +#include "benchmark_device_segmented_radix_sort_keys.parallel.hpp" +#include "benchmark_utils.hpp" + +namespace +{ +auto benchmarks = config_autotune_register::create_bulk(device_segmented_radix_sort_benchmark_generator<@BlockSize@, + @ItemsPerThread@, + @KeyType@, + @PartitionAllowed@>::create); +} // namespace diff --git a/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp new file mode 100644 index 000000000..4227d223a --- /dev/null +++ b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp @@ -0,0 +1,373 @@ +// MIT License +// +// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_KEYS_PARALLEL_HPP_ +#define ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_KEYS_PARALLEL_HPP_ + +#include +#include +#include + +// Google Benchmark +#include + +// HIP API +#include + +// rocPRIM +#include + +#include "benchmark_utils.hpp" + +template +std::string warp_sort_config_name(T const& warp_sort_config) +{ + return "{pa:" + std::to_string(warp_sort_config.partitioning_allowed) + + ",lwss:" + std::to_string(warp_sort_config.logical_warp_size_small) + + ",ipts:" + std::to_string(warp_sort_config.items_per_thread_small) + + ",bss:" + std::to_string(warp_sort_config.block_size_small) + + ",pt:" + std::to_string(warp_sort_config.partitioning_threshold) + + ",lwsm:" + std::to_string(warp_sort_config.logical_warp_size_medium) + + ",iptm:" + std::to_string(warp_sort_config.items_per_thread_medium) + + ",bsm:" + std::to_string(warp_sort_config.block_size_medium) + "}"; +} + +template +std::string config_name() +{ + const rocprim::detail::segmented_radix_sort_config_params config = Config(); + return "{bs:" + std::to_string(config.kernel_config.block_size) + + ",ipt:" + std::to_string(config.kernel_config.items_per_thread) + + ",lrb:" + std::to_string(config.long_radix_bits) + + ",srb:" + std::to_string(config.short_radix_bits) + + ",eupws:" + std::to_string(config.enable_unpartitioned_warp_sort) + + ",wsc:" + warp_sort_config_name(config.warp_sort_config) + "}"; +} + +template<> +inline std::string config_name() +{ + return 
"default_config"; +} + +template +struct device_segmented_radix_sort_benchmark : public config_autotune_interface +{ + std::string name() const override + { + using namespace std::string_literals; + const rocprim::detail::segmented_radix_sort_config_params config = Config(); + return bench_naming::format_name( + "{lvl:device,algo:segmented_radix_sort,key_type:" + std::string(Traits::name()) + + ",value_type:empty_type" + ",cfg:" + config_name() + "}"); + } + + static constexpr unsigned int batch_size = 10; + static constexpr unsigned int warmup_size = 5; + + void run_benchmark(benchmark::State& state, + size_t num_segments, + size_t mean_segment_length, + size_t target_size, + hipStream_t stream) const + { + using offset_type = int; + using key_type = Key; + + std::vector offsets; + offsets.push_back(0); + + static constexpr int seed = 716; + std::default_random_engine gen(seed); + + std::normal_distribution segment_length_dis( + static_cast(mean_segment_length), + 0.1 * mean_segment_length); + + size_t offset = 0; + for(size_t segment_index = 0; segment_index < num_segments;) + { + const double segment_length_candidate = std::round(segment_length_dis(gen)); + if(segment_length_candidate < 0) + { + continue; + } + const offset_type segment_length = static_cast(segment_length_candidate); + offset += segment_length; + offsets.push_back(offset); + ++segment_index; + } + const size_t size = offset; + const size_t segments_count = offsets.size() - 1; + + std::vector keys_input; + if(std::is_floating_point::value) + { + keys_input = get_random_data(size, + static_cast(-1000), + static_cast(1000)); + } + else + { + keys_input = get_random_data(size, + std::numeric_limits::min(), + std::numeric_limits::max()); + } + size_t batch_size = 1; + if(size < target_size) + { + batch_size = (target_size + size - 1) / size; + } + + offset_type* d_offsets; + HIP_CHECK(hipMalloc(&d_offsets, offsets.size() * sizeof(offset_type))); + HIP_CHECK(hipMemcpy(d_offsets, + offsets.data(), + 
offsets.size() * sizeof(offset_type), + hipMemcpyHostToDevice)); + + key_type* d_keys_input; + key_type* d_keys_output; + HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); + HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); + HIP_CHECK(hipMemcpy(d_keys_input, + keys_input.data(), + size * sizeof(key_type), + hipMemcpyHostToDevice)); + + void* d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; + HIP_CHECK(rocprim::segmented_radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + HIP_CHECK(hipDeviceSynchronize()); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + HIP_CHECK(rocprim::segmented_radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + } + HIP_CHECK(hipDeviceSynchronize()); + + // HIP events creation + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for(auto _ : state) + { + // Record start event + HIP_CHECK(hipEventRecord(start, stream)); + + for(size_t i = 0; i < batch_size; i++) + { + HIP_CHECK(rocprim::segmented_radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + } + + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_mseconds; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + + // Destroy HIP events + HIP_CHECK(hipEventDestroy(start)); + 
HIP_CHECK(hipEventDestroy(stop)); + + state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_offsets)); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_keys_output)); + } + + void run(benchmark::State& state, size_t size, hipStream_t stream) const override + { + constexpr std::array + segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; + constexpr std::array segment_lengths{30, 256, 3000, 300000}; + + for(const auto segment_count : segment_counts) + { + for(const auto segment_length : segment_lengths) + { + const auto number_of_elements = segment_count * segment_length; + if(number_of_elements > 33554432 || number_of_elements < 300000) + { + continue; + } + + run_benchmark(state, segment_count, segment_length, size, stream); + } + } + } +}; + +template class T, bool enable, Tp... Idx> +struct decider; +template +struct device_segmented_radix_sort_benchmark_generator +{ + template + struct create_lrb + { + template + struct create_srb + { + template + struct create_euws + { + template + struct create_lwss + { + template + struct create_pt + { + void operator()( + std::vector>& storage) + { + storage.emplace_back( + std::make_unique>>>()); + } + }; + + void + operator()(std::vector>& storage) + { + static_for_each, create_pt>(storage); + } + }; + + void operator()(std::vector>& storage) + { + if(PartitionAllowed) + { + + static_for_each, + create_lwss>(storage); + } + else + { + storage.emplace_back(std::make_unique>>()); + } + } + }; + + void operator()(std::vector>& storage) + { + decider::do_the_thing( + storage); + } + }; + + void operator()(std::vector>& storage) + { + decider::do_the_thing( + storage); + } + }; + + static void create(std::vector>& storage) + { + static_for_each, create_lrb>(storage); + } +}; + +template class T, Tp... 
Idx> +struct decider +{ + inline static void + do_the_thing(std::vector>& storage) + { + static_for_each, T>(storage); + } +}; + +template class T, Tp... Idx> +struct decider +{ + inline static void + do_the_thing(std::vector>& /*storage*/) + {} +}; + +#endif // ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_KEYS_PARALLEL_HPP_ diff --git a/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp b/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp new file mode 100644 index 000000000..4aba2b436 --- /dev/null +++ b/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp @@ -0,0 +1,357 @@ +// MIT License +// +// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#include +#include +#include +#include +#include + +// Google Benchmark +#include "benchmark/benchmark.h" +// CmdParser +#include "benchmark_utils.hpp" +#include "cmdparser.hpp" + +// HIP API +#include + +// rocPRIM +#include + +#ifndef DEFAULT_N +const size_t DEFAULT_N = 1024 * 1024 * 32; +#endif + +namespace rp = rocprim; + +namespace +{ + +constexpr unsigned int warmup_size = 2; +constexpr size_t min_size = 30000; +constexpr std::array segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; +constexpr std::array segment_lengths{30, 256, 3000, 300000}; +} // namespace + +// This benchmark only handles the rocprim::segmented_radix_sort_pairs function. The benchmark was separated into two (keys and pairs), +// because the binary became too large to link: it ran into a "relocation R_X86_64_PC32 out of range" error. +// This happens partly because the algorithm has 4 kernels and decides at runtime which one to call. + +template +void run_sort_pairs_benchmark(benchmark::State& state, + size_t num_segments, + size_t mean_segment_length, + size_t target_size, + hipStream_t stream) +{ + using offset_type = int; + using key_type = Key; + using value_type = Value; + + // Generate data + std::vector offsets; + offsets.push_back(0); + + static constexpr int seed = 716; + std::default_random_engine gen(seed); + + std::normal_distribution segment_length_dis(static_cast(mean_segment_length), + 0.1 * mean_segment_length); + + size_t offset = 0; + for(size_t segment_index = 0; segment_index < num_segments;) + { + const double segment_length_candidate = std::round(segment_length_dis(gen)); + if(segment_length_candidate < 0) + { + continue; + } + const offset_type segment_length = static_cast(segment_length_candidate); + offset += segment_length; + offsets.push_back(offset); + ++segment_index; + } + const size_t size = offset; + const size_t segments_count = offsets.size() - 1; + + std::vector keys_input; + if(std::is_floating_point::value) + { + keys_input =
get_random_data(size, + static_cast(-1000), + static_cast(1000)); + } + else + { + keys_input = get_random_data(size, + std::numeric_limits::min(), + std::numeric_limits::max()); + } + size_t batch_size = 1; + if(size < target_size) + { + batch_size = (target_size + size - 1) / size; + } + + std::vector values_input(size); + std::iota(values_input.begin(), values_input.end(), 0); + + offset_type* d_offsets; + HIP_CHECK(hipMalloc(&d_offsets, (segments_count + 1) * sizeof(offset_type))); + HIP_CHECK(hipMemcpy(d_offsets, + offsets.data(), + (segments_count + 1) * sizeof(offset_type), + hipMemcpyHostToDevice)); + + key_type* d_keys_input; + key_type* d_keys_output; + HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); + HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); + HIP_CHECK( + hipMemcpy(d_keys_input, keys_input.data(), size * sizeof(key_type), hipMemcpyHostToDevice)); + + value_type* d_values_input; + value_type* d_values_output; + HIP_CHECK(hipMalloc(&d_values_input, size * sizeof(value_type))); + HIP_CHECK(hipMalloc(&d_values_output, size * sizeof(value_type))); + HIP_CHECK(hipMemcpy(d_values_input, + values_input.data(), + size * sizeof(value_type), + hipMemcpyHostToDevice)); + + void* d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; + HIP_CHECK(rp::segmented_radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + HIP_CHECK(hipDeviceSynchronize()); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + HIP_CHECK(rp::segmented_radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + 
false)); + } + HIP_CHECK(hipDeviceSynchronize()); + + // HIP events creation + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for(auto _ : state) + { + // Record start event + HIP_CHECK(hipEventRecord(start, stream)); + + for(size_t i = 0; i < batch_size; i++) + { + HIP_CHECK(rp::segmented_radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + } + + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_mseconds; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + + // Destroy HIP events + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + state.SetBytesProcessed(state.iterations() * batch_size * size + * (sizeof(key_type) + sizeof(value_type))); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_offsets)); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_keys_output)); + HIP_CHECK(hipFree(d_values_input)); + HIP_CHECK(hipFree(d_values_output)); +} + +template +void add_sort_pairs_benchmarks(std::vector& benchmarks, + hipStream_t stream, + size_t max_size, + size_t min_size, + size_t target_size) +{ + std::string key_name = Traits::name(); + std::string value_name = Traits::name(); + for(const auto segment_count : segment_counts) + { + for(const auto segment_length : segment_lengths) + { + const auto number_of_elements = segment_count * segment_length; + if(number_of_elements > max_size || number_of_elements < min_size) + { + continue; + } + benchmarks.push_back(benchmark::RegisterBenchmark( + bench_naming::format_name( + 
"{lvl:device,algo:radix_sort_segmented,key_type:" + key_name + ",value_type:" + + value_name + ",segment_count:" + std::to_string(segment_count) + + ",segment_length:" + std::to_string(segment_length) + ",cfg:default_config}") + .c_str(), + [=](benchmark::State& state) + { + run_sort_pairs_benchmark(state, + segment_count, + segment_length, + target_size, + stream); + })); + } + } +} + +int main(int argc, char* argv[]) +{ + cli::Parser parser(argc, argv); + parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("trials", "trials", -1, "number of iterations"); + parser.set_optional("name_format", + "name_format", + "human", + "either: json,human,txt"); + +#ifdef BENCHMARK_CONFIG_TUNING + // optionally run an evenly split subset of benchmarks, when making multiple program invocations + parser.set_optional("parallel_instance", + "parallel_instance", + 0, + "parallel instance index"); + parser.set_optional("parallel_instances", + "parallel_instances", + 1, + "total parallel instances"); +#endif + + parser.run_and_exit_if_error(); + + // Parse argv + benchmark::Initialize(&argc, argv); + const size_t size = parser.get("size"); + const int trials = parser.get("trials"); + bench_naming::set_format(parser.get("name_format")); + + // HIP + hipStream_t stream = 0; // default + + // Benchmark info + add_common_benchmark_info(); + benchmark::AddCustomContext("size", std::to_string(size)); + + // Add benchmarks + std::vector benchmarks; +#ifdef BENCHMARK_CONFIG_TUNING + const int parallel_instance = parser.get("parallel_instance"); + const int parallel_instances = parser.get("parallel_instances"); + config_autotune_register::register_benchmark_subset(benchmarks, + parallel_instance, + parallel_instances, + size, + stream); +#else + using custom_float2 = custom_type; + using custom_double2 = custom_type; + add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, 
size / 2); + add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_pairs_benchmarks(benchmarks, + stream, + size, + min_size, + size / 2); + add_sort_pairs_benchmarks(benchmarks, stream, size, min_size, size / 2); + add_sort_pairs_benchmarks(benchmarks, + stream, + size, + min_size, + size / 2); +#endif + + // Use manual timing + for(auto& b : benchmarks) + { + b->UseManualTime(); + b->Unit(benchmark::kMillisecond); + } + + // Force number of iterations + if(trials > 0) + { + for(auto& b : benchmarks) + { + b->Iterations(trials); + } + } + + // Run benchmarks + benchmark::RunSpecifiedBenchmarks(); + return 0; +} diff --git a/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.cpp.in b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.cpp.in new file mode 100644 index 000000000..55fc0a849 --- /dev/null +++ b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.cpp.in @@ -0,0 +1,35 @@ +// MIT License +// +// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include + +#include "benchmark_device_segmented_radix_sort_pairs.parallel.hpp" +#include "benchmark_utils.hpp" + +namespace +{ +auto benchmarks = config_autotune_register::create_bulk(device_segmented_radix_sort_benchmark_generator<@BlockSize@, + @ItemsPerThread@, + @KeyType@, + @ValueType@, + @PartitionAllowed@>::create); +} // namespace diff --git a/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp new file mode 100644 index 000000000..917af3a25 --- /dev/null +++ b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp @@ -0,0 +1,412 @@ +// MIT License +// +// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_PAIRS_PARALLEL_HPP_ +#define ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_PAIRS_PARALLEL_HPP_ + +#include +#include +#include + +// Google Benchmark +#include + +// HIP API +#include + +// rocPRIM +#include + +#include "benchmark_utils.hpp" + +template +std::string warp_sort_config_name(T const& warp_sort_config) +{ + return "{pa:" + std::to_string(warp_sort_config.partitioning_allowed) + + ",lwss:" + std::to_string(warp_sort_config.logical_warp_size_small) + + ",ipts:" + std::to_string(warp_sort_config.items_per_thread_small) + + ",bss:" + std::to_string(warp_sort_config.block_size_small) + + ",pt:" + std::to_string(warp_sort_config.partitioning_threshold) + + ",lwsm:" + std::to_string(warp_sort_config.logical_warp_size_medium) + + ",iptm:" + std::to_string(warp_sort_config.items_per_thread_medium) + + ",bsm:" + std::to_string(warp_sort_config.block_size_medium) + "}"; +} + +template +std::string config_name() +{ + const rocprim::detail::segmented_radix_sort_config_params config = Config(); + return "{bs:" + std::to_string(config.kernel_config.block_size) + + ",ipt:" + std::to_string(config.kernel_config.items_per_thread) + + ",lrb:" + std::to_string(config.long_radix_bits) + + ",srb:" + std::to_string(config.short_radix_bits) + + ",eupws:" + std::to_string(config.enable_unpartitioned_warp_sort) + + ",wsc:" + warp_sort_config_name(config.warp_sort_config) + "}"; +} + +template<> +inline std::string config_name() +{ + return "default_config"; +} + +template +struct device_segmented_radix_sort_benchmark : public config_autotune_interface +{ + std::string name() const override + { + using namespace std::string_literals; + const 
rocprim::detail::segmented_radix_sort_config_params config = Config(); + return bench_naming::format_name("{lvl:device,algo:segmented_radix_sort,key_type:" + + std::string(Traits::name()) + + ",value_type:" + std::string(Traits::name()) + + ",cfg:" + config_name() + "}"); + } + + static constexpr unsigned int batch_size = 10; + static constexpr unsigned int warmup_size = 5; + + void run_benchmark(benchmark::State& state, + size_t num_segments, + size_t mean_segment_length, + size_t target_size, + hipStream_t stream) const + { + using offset_type = int; + using key_type = Key; + using value_type = Value; + + std::vector offsets; + offsets.push_back(0); + + static constexpr int seed = 716; + std::default_random_engine gen(seed); + + std::normal_distribution segment_length_dis( + static_cast(mean_segment_length), + 0.1 * mean_segment_length); + + size_t offset = 0; + for(size_t segment_index = 0; segment_index < num_segments;) + { + const double segment_length_candidate = std::round(segment_length_dis(gen)); + if(segment_length_candidate < 0) + { + continue; + } + const offset_type segment_length = static_cast(segment_length_candidate); + offset += segment_length; + offsets.push_back(offset); + ++segment_index; + } + const size_t size = offset; + const size_t segments_count = offsets.size() - 1; + + std::vector keys_input; + if(std::is_floating_point::value) + { + keys_input = get_random_data(size, + static_cast(-1000), + static_cast(1000)); + } + else + { + keys_input = get_random_data(size, + std::numeric_limits::min(), + std::numeric_limits::max()); + } + + std::vector values_input; + if(std::is_floating_point::value) + { + values_input = get_random_data(size, + static_cast(-1000), + static_cast(1000)); + } + else + { + values_input = get_random_data(size, + std::numeric_limits::min(), + std::numeric_limits::max()); + } + + size_t batch_size = 1; + if(size < target_size) + { + batch_size = (target_size + size - 1) / size; + } + + offset_type* d_offsets; + 
HIP_CHECK(hipMalloc(&d_offsets, offsets.size() * sizeof(offset_type))); + HIP_CHECK(hipMemcpy(d_offsets, + offsets.data(), + offsets.size() * sizeof(offset_type), + hipMemcpyHostToDevice)); + + key_type* d_keys_input; + key_type* d_keys_output; + HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); + HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); + HIP_CHECK(hipMemcpy(d_keys_input, + keys_input.data(), + size * sizeof(key_type), + hipMemcpyHostToDevice)); + + value_type* d_values_input; + value_type* d_values_output; + HIP_CHECK(hipMalloc(&d_values_input, size * sizeof(value_type))); + HIP_CHECK(hipMalloc(&d_values_output, size * sizeof(value_type))); + HIP_CHECK(hipMemcpy(d_values_input, + values_input.data(), + size * sizeof(value_type), + hipMemcpyHostToDevice)); + + void* d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; + HIP_CHECK(rocprim::segmented_radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + HIP_CHECK(hipDeviceSynchronize()); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + HIP_CHECK(rocprim::segmented_radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + } + HIP_CHECK(hipDeviceSynchronize()); + + // HIP events creation + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for(auto _ : state) + { + // Record start event + HIP_CHECK(hipEventRecord(start, stream)); + + for(size_t i = 0; i < batch_size; i++) + { + HIP_CHECK(rocprim::segmented_radix_sort_pairs(d_temporary_storage, + 
temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + 0, + sizeof(key_type) * 8, + stream, + false)); + } + + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_mseconds; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + + // Destroy HIP events + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + // Throughput covers both arrays: every sorted item carries a key and a value. + state.SetBytesProcessed(state.iterations() * batch_size * size * (sizeof(key_type) + sizeof(value_type))); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_offsets)); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_keys_output)); + HIP_CHECK(hipFree(d_values_input)); + HIP_CHECK(hipFree(d_values_output)); + } + + void run(benchmark::State& state, size_t size, hipStream_t stream) const override + { + constexpr std::array + segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; + constexpr std::array segment_lengths{30, 256, 3000, 300000}; + + for(const auto segment_count : segment_counts) + { + for(const auto segment_length : segment_lengths) + { + const auto number_of_elements = segment_count * segment_length; + if(number_of_elements > 33554432 || number_of_elements < 300000) + { + continue; + } + + run_benchmark(state, segment_count, segment_length, size, stream); + } + } + } +}; + +template class T, bool enable, Tp...
Idx> +struct decider; +template +struct device_segmented_radix_sort_benchmark_generator +{ + template + struct create_lrb + { + template + struct create_srb + { + template + struct create_euws + { + template + struct create_lwss + { + template + struct create_pt + { + void operator()( + std::vector>& storage) + { + storage.emplace_back( + std::make_unique>>>()); + } + }; + + void + operator()(std::vector>& storage) + { + static_for_each, create_pt>(storage); + } + }; + + void operator()(std::vector>& storage) + { + if(PartitionAllowed) + { + + static_for_each, + create_lwss>(storage); + } + else + { + storage.emplace_back(std::make_unique>>()); + } + } + }; + + void operator()(std::vector>& storage) + { + static_for_each, create_euws>(storage); + } + }; + + void operator()(std::vector>& storage) + { + decider::do_the_thing( + storage); + } + }; + + static void create(std::vector>& storage) + { + static_for_each, create_lrb>(storage); + } +}; + +template class T, Tp... Idx> +struct decider +{ + inline static void + do_the_thing(std::vector>& storage) + { + static_for_each, T>(storage); + } +}; + +template class T, Tp... Idx> +struct decider +{ + inline static void + do_the_thing(std::vector>& /*storage*/) + {} +}; + +#endif // ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_PAIRS_PARALLEL_HPP_ diff --git a/benchmark/benchmark_utils.hpp b/benchmark/benchmark_utils.hpp index 2d67ff3fa..e5863e758 100644 --- a/benchmark/benchmark_utils.hpp +++ b/benchmark/benchmark_utils.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -306,12 +306,8 @@ inline bool is_warp_size_supported(const unsigned int required_warp_size, const } template -struct DeviceSelectWarpSize -{ - static constexpr unsigned int value = ::rocprim::device_warp_size() >= LogicalWarpSize - ? LogicalWarpSize - : ::rocprim::device_warp_size(); -}; +__device__ constexpr bool device_test_enabled_for_warp_size_v + = ::rocprim::device_warp_size() >= LogicalWarpSize; template std::vector diff --git a/benchmark/benchmark_warp_exchange.cpp b/benchmark/benchmark_warp_exchange.cpp index 997c2e357..64a0c65a5 100644 --- a/benchmark/benchmark_warp_exchange.cpp +++ b/benchmark/benchmark_warp_exchange.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -124,16 +124,14 @@ struct ScatterToStripedOp } }; -template< - class T, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int LogicalWarpSize, - class Op -> -__global__ -__launch_bounds__(BlockSize) -auto warp_exchange_kernel(T* d_output, unsigned int trials) -> typename std::enable_if::value, void>::type +template +__device__ auto warp_exchange_benchmark(T* d_output, unsigned int trials) + -> std::enable_if_t + && !std::is_same::value> { T thread_data[ItemsPerThread]; @@ -141,16 +139,12 @@ auto warp_exchange_kernel(T* d_output, unsigned int trials) -> typename std::ena for(unsigned int i = 0; i < ItemsPerThread; i++) { // generate unique value each data-element - thread_data[i] = static_cast(hipThreadIdx_x*ItemsPerThread+i); + thread_data[i] = static_cast(threadIdx.x * ItemsPerThread + i); } - using 
warp_exchange_type = ::rocprim::warp_exchange< - T, - ItemsPerThread, - DeviceSelectWarpSize::value - >; + using warp_exchange_type = ::rocprim::warp_exchange; constexpr unsigned int warps_in_block = BlockSize / LogicalWarpSize; - const unsigned int warp_id = hipThreadIdx_x / LogicalWarpSize; + const unsigned int warp_id = threadIdx.x / LogicalWarpSize; ROCPRIM_SHARED_MEMORY typename warp_exchange_type::storage_type storage[warps_in_block]; ROCPRIM_NO_UNROLL @@ -163,44 +157,37 @@ auto warp_exchange_kernel(T* d_output, unsigned int trials) -> typename std::ena ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) { - const unsigned int global_idx = - (BlockSize * hipBlockIdx_x + hipThreadIdx_x) * ItemsPerThread + i; + const unsigned int global_idx = (BlockSize * blockIdx.x + threadIdx.x) * ItemsPerThread + i; d_output[global_idx] = thread_data[i]; } } -template< - class T, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int LogicalWarpSize, - class Op - > -__global__ -__launch_bounds__(BlockSize) - auto warp_exchange_kernel(T* d_output, unsigned int trials) -> typename std::enable_if::value, void>::type +template +__device__ auto warp_exchange_benchmark(T* d_output, unsigned int trials) + -> std::enable_if_t + && std::is_same::value> { T thread_data[ItemsPerThread]; unsigned int thread_ranks[ItemsPerThread]; constexpr unsigned int warps_in_block = BlockSize / LogicalWarpSize; - const unsigned int warp_id = hipThreadIdx_x / LogicalWarpSize; - const unsigned int lane_id = hipThreadIdx_x % LogicalWarpSize; + const unsigned int warp_id = threadIdx.x / LogicalWarpSize; + const unsigned int lane_id = threadIdx.x % LogicalWarpSize; ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) { // generate unique value each data-element - thread_data[i] = static_cast(hipThreadIdx_x*ItemsPerThread+i); + thread_data[i] = static_cast(threadIdx.x * ItemsPerThread + i); // generate unique destination location for each data-element const unsigned 
int s_lane_id = i % 2 == 0 ? LogicalWarpSize - 1 - lane_id : lane_id; thread_ranks[i] = s_lane_id*ItemsPerThread+i; // scatter values in warp across whole storage } - using warp_exchange_type = ::rocprim::warp_exchange< - T, - ItemsPerThread, - DeviceSelectWarpSize::value - >; + using warp_exchange_type = ::rocprim::warp_exchange; ROCPRIM_SHARED_MEMORY typename warp_exchange_type::storage_type storage[warps_in_block]; ROCPRIM_NO_UNROLL @@ -213,12 +200,30 @@ __launch_bounds__(BlockSize) ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) { - const unsigned int global_idx = - (BlockSize * hipBlockIdx_x + hipThreadIdx_x) * ItemsPerThread + i; + const unsigned int global_idx = (BlockSize * blockIdx.x + threadIdx.x) * ItemsPerThread + i; d_output[global_idx] = thread_data[i]; } } +template +__device__ auto warp_exchange_benchmark(T* /*d_output*/, unsigned int /*trials*/) + -> std::enable_if_t> +{} + +template +__global__ __launch_bounds__(BlockSize) void warp_exchange_kernel(T* d_output, unsigned int trials) +{ + warp_exchange_benchmark(d_output, trials); +} + template< class T, unsigned int BlockSize, @@ -245,18 +250,8 @@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t N) // Record start event HIP_CHECK(hipEventRecord(start, stream)); - hipLaunchKernelGGL( - HIP_KERNEL_NAME(warp_exchange_kernel< - T, - BlockSize, - ItemsPerThread, - LogicalWarpSize, - Op - > - ), - dim3(size / items_per_block), dim3(BlockSize), 0, stream, - d_output, trials - ); + warp_exchange_kernel + <<>>(d_output, trials); HIP_CHECK(hipPeekAtLastError()); diff --git a/benchmark/benchmark_warp_scan.cpp b/benchmark/benchmark_warp_scan.cpp index e9ba06674..daee015ae 100644 --- a/benchmark/benchmark_warp_scan.cpp +++ b/benchmark/benchmark_warp_scan.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -179,9 +179,8 @@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t size) template void add_benchmarks(std::vector& benchmarks, - const std::string& method_name, - hipStream_t stream, - size_t size) + hipStream_t stream, + size_t size) { using custom_double2 = custom_type; using custom_int_double = custom_type; @@ -226,8 +225,8 @@ int main(int argc, char *argv[]) // Add benchmarks std::vector benchmarks; - add_benchmarks(benchmarks, "inclusive", stream, size); - add_benchmarks(benchmarks, "exclusive", stream, size); + add_benchmarks(benchmarks, stream, size); //inclusive + add_benchmarks(benchmarks, stream, size); //exclusive // Use manual timing for(auto& b : benchmarks) diff --git a/docs/conf.py b/docs/conf.py index d70124f3c..58c7445ed 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,3 +1,23 @@ +# Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. For a full @@ -6,17 +26,14 @@ import re -from rocm_docs import ROCmDocs - with open('../CMakeLists.txt', encoding='utf-8') as f: match = re.search(r'.*\bset\(VERSION_STRING\s+\"?([0-9.]+)[^0-9.]+', f.read()) if not match: raise ValueError("VERSION not found!") version_number = match[1] -left_nav_title = f"rocPRIM {version_number} Documentation" -# for PDF output on Read the Docs project = "rocPRIM Documentation" +html_title = f"rocPRIM {version_number} Documentation" author = "Advanced Micro Devices, Inc." copyright = "Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved." version = version_number @@ -24,14 +41,18 @@ external_toc_path = "./sphinx/_toc.yml" -docs_core = ROCmDocs(left_nav_title) -docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/xml") -docs_core.setup() +extensions = ["rocm_docs", "rocm_docs.doxygen"] -external_projects_current_project = "rocprim" +html_theme = "rocm_docs_theme" -for sphinx_var in ROCmDocs.SPHINX_VARS: - globals()[sphinx_var] = getattr(docs_core, sphinx_var) +doxygen_root = "doxygen" +doxygen_project = { + "name": "rocPRIM", + "path": "doxygen/xml", +} + +external_projects = [] +external_projects_current_project = "rocprim" cpp_id_attributes = ["__global__", "__device__", "__host__", "__forceinline__", "static"] cpp_paren_attributes = ["__declspec"] diff --git a/docs/device_ops/adjacent_difference.rst b/docs/device_ops/adjacent_difference.rst index 987614cd4..670d5783f 100644 --- a/docs/device_ops/adjacent_difference.rst +++ b/docs/device_ops/adjacent_difference.rst @@ -23,6 +23,11 @@ left, inplace .. 
doxygenfunction:: rocprim::adjacent_difference_inplace(void *const temporary_storage, std::size_t &storage_size, const InputIt values, const std::size_t size, const BinaryFunction op=BinaryFunction {}, const hipStream_t stream=0, const bool debug_synchronous=false) +left, aliased +~~~~~~~~~~~~~ + +.. doxygenfunction:: rocprim::adjacent_difference_inplace(void *const temporary_storage, std::size_t &storage_size, const InputIt input, const OutputIt output, const std::size_t size, const BinaryFunction op=BinaryFunction {}, const hipStream_t stream=0, const bool debug_synchronous=false) + right ============= @@ -33,3 +38,8 @@ right, inplace .. doxygenfunction:: rocprim::adjacent_difference_right_inplace(void *const temporary_storage, std::size_t &storage_size, const InputIt values, const std::size_t size, const BinaryFunction op=BinaryFunction {}, const hipStream_t stream=0, const bool debug_synchronous=false) +right, aliased +~~~~~~~~~~~~~~ + +.. doxygenfunction:: rocprim::adjacent_difference_right_inplace(void *const temporary_storage, std::size_t &storage_size, const InputIt input, const OutputIt output, const std::size_t size, const BinaryFunction op=BinaryFunction {}, const hipStream_t stream=0, const bool debug_synchronous=false) + diff --git a/docs/device_ops/index.rst b/docs/device_ops/index.rst index b75bb2184..1f01a7414 100644 --- a/docs/device_ops/index.rst +++ b/docs/device_ops/index.rst @@ -21,3 +21,4 @@ * :ref:`dev-adjacent_difference` * :ref:`dev-binary_search` * :ref:`dev-histogram` + * :ref:`dev-memcpy` diff --git a/docs/device_ops/memcpy.rst b/docs/device_ops/memcpy.rst new file mode 100644 index 000000000..8330daed5 --- /dev/null +++ b/docs/device_ops/memcpy.rst @@ -0,0 +1,19 @@ +.. meta:: + :description: rocPRIM documentation and API reference library + :keywords: rocPRIM, ROCm, API, documentation + +.. _dev-memcpy: + + +Memcpy +------ + +Configuring the kernel +~~~~~~~~~~~~~~~~~~~~~~ + +.. 
doxygenstruct:: rocprim::batch_memcpy_config + +batch_memcpy +~~~~~~~~~~~~ + +.. doxygenfunction:: rocprim::batch_memcpy(void* temporary_storage, size_t& storage_size, InputBufferItType sources, OutputBufferItType destinations, BufferSizeItType sizes, uint32_t num_copies, hipStream_t stream = hipStreamDefault, bool debug_synchronous = false) diff --git a/docs/device_ops/partition.rst b/docs/device_ops/partition.rst index a10d95920..ecd5636a9 100644 --- a/docs/device_ops/partition.rst +++ b/docs/device_ops/partition.rst @@ -13,6 +13,11 @@ partition .. doxygenfunction:: rocprim::partition(void *temporary_storage, size_t &storage_size, InputIterator input, OutputIterator output, SelectedCountOutputIterator selected_count_output, const size_t size, UnaryPredicate predicate, const hipStream_t stream=0, const bool debug_synchronous=false) +partition_two_way +~~~~~~~~~~~~~~~~~ + +.. doxygenfunction:: rocprim::partition_two_way(void* temporary_storage, size_t& storage_size, InputIterator input, SelectedOutputIterator output_selected, RejectedOutputIterator output_rejected, SelectedCountOutputIterator selected_count_output, const size_t size, Predicate predicate, const hipStream_t stream = 0, const bool debug_synchronous = false) + partition_three_way ====================== diff --git a/docs/reference/intrinsics.rst b/docs/reference/intrinsics.rst index 0aac48f55..5f09f174f 100644 --- a/docs/reference/intrinsics.rst +++ b/docs/reference/intrinsics.rst @@ -52,4 +52,6 @@ Active threads ================== .. doxygenfunction:: rocprim::ballot (int predicate) +.. doxygenfunction:: rocprim::group_elect(lane_mask_type mask) .. doxygenfunction:: rocprim::masked_bit_count (lane_mask_type x, unsigned int add=0) +.. 
doxygenfunction:: rocprim::match_any(unsigned int label, bool valid = true) diff --git a/docs/reference/ops_summary.rst b/docs/reference/ops_summary.rst index 01d7238fc..95cfa3f93 100644 --- a/docs/reference/ops_summary.rst +++ b/docs/reference/ops_summary.rst @@ -44,8 +44,9 @@ Partition/Merge Data Movement =============== -* ``store`` stores the sequence to a continuous memory zone. There are variations to use an optimized path or to specify how to store the sequence to better fit the access patterns of the CUs -* ``load`` the complementary operations of the above ones +* ``store`` stores the sequence to a continuous memory zone. There are variations to use an optimized path or to specify how to store the sequence to better fit the access patterns of the CUs. +* ``load`` the complementary operations of the above ones. +* ``memcpy`` copies bytes between device sources and destinations Other operations ====================== diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index f0236fd48..07afd927c 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -23,13 +23,14 @@ subtrees: - file: device_ops/sort.rst - file: device_ops/merge.rst - file: device_ops/partition.rst - - file: device_ops/run_lenght_encoding.rst + - file: device_ops/run_length_encoding.rst - file: device_ops/scan.rst - file: device_ops/select.rst - file: device_ops/reduce.rst - file: device_ops/adjacent_difference.rst - file: device_ops/binary_search.rst - file: device_ops/histogram.rst + - file: device_ops/memcpy.rst - file: block_ops/index.rst subtrees: - entries: @@ -60,6 +61,5 @@ subtrees: - file: reference/thread_ops.rst - file: reference/iterators.rst - file: reference/intrinsics.rst - - file: reference/reorder.rst - file: reference/acknowledge.rst - file: license.rst diff --git a/rocprim/CMakeLists.txt b/rocprim/CMakeLists.txt index 2b02c7f2f..037997d7d 100644 --- a/rocprim/CMakeLists.txt +++ b/rocprim/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright 
(c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,7 +29,7 @@ configure_file( @ONLY ) -if (BUILD_FILE_REORG_BACKWARD_COMPATIBILITY AND NOT WIN32) +if (ROCPRIM_INSTALL AND BUILD_FILE_REORG_BACKWARD_COMPATIBILITY AND NOT WIN32) rocm_wrap_header_file( "rocprim_version.hpp" WRAPPER_LOCATIONS rocprim/include/rocprim @@ -54,35 +54,36 @@ target_link_libraries(rocprim_hip INTERFACE rocprim hip::device) # Installation +if (ROCPRIM_INSTALL) + # We need to install headers manually as rocm_install_targets + # does not support header-only libraries (INTERFACE targets) + rocm_install_targets( + TARGETS rocprim rocprim_hip + ) -# We need to install headers manually as rocm_install_targets -# does not support header-only libraries (INTERFACE targets) -rocm_install_targets( - TARGETS rocprim rocprim_hip -) + rocm_install( + DIRECTORY + "include/" + "${PROJECT_BINARY_DIR}/rocprim/include/" + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ + FILES_MATCHING + PATTERN "*.h" + PATTERN "*.hpp" + PERMISSIONS OWNER_WRITE OWNER_READ GROUP_READ WORLD_READ + ) -rocm_install( - DIRECTORY - "include/" - "${PROJECT_BINARY_DIR}/rocprim/include/" - DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ - FILES_MATCHING - PATTERN "*.h" - PATTERN "*.hpp" - PERMISSIONS OWNER_WRITE OWNER_READ GROUP_READ WORLD_READ -) + if (BUILD_FILE_REORG_BACKWARD_COMPATIBILITY AND NOT WIN32) + rocm_install( + DIRECTORY + "${PROJECT_BINARY_DIR}/rocprim/wrapper/" + DESTINATION rocprim/ ) + endif() -if (BUILD_FILE_REORG_BACKWARD_COMPATIBILITY AND NOT WIN32) -rocm_install( - DIRECTORY - "${PROJECT_BINARY_DIR}/rocprim/wrapper/" - DESTINATION rocprim/ ) + # Export targets + rocm_export_targets( + TARGETS roc::rocprim roc::rocprim_hip + DEPENDS PACKAGE hip + NAMESPACE roc:: + ) endif() - -# 
Export targets -rocm_export_targets( - TARGETS roc::rocprim roc::rocprim_hip - DEPENDS PACKAGE hip - NAMESPACE roc:: -) diff --git a/rocprim/include/rocprim/detail/match_result_type.hpp b/rocprim/include/rocprim/detail/match_result_type.hpp deleted file mode 100644 index 3add1f1a2..000000000 --- a/rocprim/include/rocprim/detail/match_result_type.hpp +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) 2018-2021 Advanced Micro Devices, Inc. All rights reserved. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -#ifndef ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_ -#define ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_ - -#include - -#include "../config.hpp" - -BEGIN_ROCPRIM_NAMESPACE -namespace detail -{ - -// invoke_result is based on std::invoke_result. -// The main difference is using ROCPRIM_HOST_DEVICE, this allows to -// use invoke_result with device-only lambdas/functors in host-only functions -// on HIP-clang. 
- -template -struct is_reference_wrapper : std::false_type {}; -template -struct is_reference_wrapper> : std::true_type {}; - -template -struct invoke_impl { - template - ROCPRIM_HOST_DEVICE - static auto call(F&& f, Args&&... args) - -> decltype(std::forward(f)(std::forward(args)...)); -}; - -template -struct invoke_impl -{ - template::type, - class = typename std::enable_if::value>::type - > - ROCPRIM_HOST_DEVICE - static auto get(T&& t) -> T&&; - - template::type, - class = typename std::enable_if::value>::type - > - ROCPRIM_HOST_DEVICE - static auto get(T&& t) -> decltype(t.get()); - - template::type, - class = typename std::enable_if::value>::type, - class = typename std::enable_if::value>::type - > - ROCPRIM_HOST_DEVICE - static auto get(T&& t) -> decltype(*std::forward(t)); - - template::value>::type - > - ROCPRIM_HOST_DEVICE - static auto call(MT1 B::*pmf, T&& t, Args&&... args) - -> decltype((invoke_impl::get(std::forward(t)).*pmf)(std::forward(args)...)); - - template - ROCPRIM_HOST_DEVICE - static auto call(MT B::*pmd, T&& t) - -> decltype(invoke_impl::get(std::forward(t)).*pmd); -}; - -template::type> -ROCPRIM_HOST_DEVICE -auto INVOKE(F&& f, Args&&... 
args) - -> decltype(invoke_impl::call(std::forward(f), std::forward(args)...)); - -// Conforming C++14 implementation (is also a valid C++11 implementation): -template -struct invoke_result_impl { }; -template -struct invoke_result_impl(), std::declval()...))), F, Args...> -{ - using type = decltype(INVOKE(std::declval(), std::declval()...)); -}; - -template -struct invoke_result : invoke_result_impl {}; - -template -struct match_result_type -{ - using type = typename invoke_result::type; -}; - -} // end namespace detail -END_ROCPRIM_NAMESPACE - -#endif // ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_ diff --git a/rocprim/include/rocprim/detail/temp_storage.hpp b/rocprim/include/rocprim/detail/temp_storage.hpp index b31853165..f9e2a57d6 100644 --- a/rocprim/include/rocprim/detail/temp_storage.hpp +++ b/rocprim/include/rocprim/detail/temp_storage.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -65,7 +65,7 @@ struct simple_partition layout storage_layout; /// Compute the required layout for this type and return it. - layout get_layout() + layout get_layout() const { return this->storage_layout; } @@ -129,13 +129,13 @@ struct linear_partition linear_partition(Ts... sub_partitions) : sub_partitions{sub_partitions...} {} /// Compute the required layout for this type and return it. - layout get_layout() + layout get_layout() const { size_t required_alignment = 1; size_t required_size = 0; for_each_in_tuple(this->sub_partitions, - [&](auto& sub_partition) + [&](const auto& sub_partition) { const auto sub_layout = sub_partition.get_layout(); @@ -197,13 +197,13 @@ struct union_partition union_partition(Ts... 
sub_partitions) : sub_partitions{sub_partitions...} {} /// Compute the required layout for this type and return it. - layout get_layout() + layout get_layout() const { size_t required_alignment = 1; size_t required_size = 0; for_each_in_tuple(this->sub_partitions, - [&](auto& sub_partition) + [&](const auto& sub_partition) { const auto sub_layout = sub_partition.get_layout(); diff --git a/rocprim/include/rocprim/detail/various.hpp b/rocprim/include/rocprim/detail/various.hpp index 48203e672..1def58e3d 100644 --- a/rocprim/include/rocprim/detail/various.hpp +++ b/rocprim/include/rocprim/detail/various.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -329,19 +329,23 @@ constexpr std::add_const_t* as_const_ptr(T* ptr) return ptr; } -template -ROCPRIM_HOST_DEVICE inline void for_each_in_tuple_impl(::rocprim::tuple& t, - Function f, - ::rocprim::index_sequence) +template +ROCPRIM_HOST_DEVICE inline void + for_each_in_tuple_impl(Tuple&& t, Function&& f, ::rocprim::index_sequence) { - auto swallow = {(f(::rocprim::get(t)), 0)...}; + int swallow[] + = {(std::forward(f)(::rocprim::get(std::forward(t))), 0)...}; (void)swallow; } -template -ROCPRIM_HOST_DEVICE inline void for_each_in_tuple(::rocprim::tuple& t, Function f) +template +ROCPRIM_HOST_DEVICE inline auto for_each_in_tuple(Tuple&& t, Function&& f) + -> void_t>> { - for_each_in_tuple_impl(t, f, ::rocprim::index_sequence_for()); + static constexpr size_t size = tuple_size>::value; + for_each_in_tuple_impl(std::forward(t), + std::forward(f), + ::rocprim::make_index_sequence()); } /// \brief Reinterprets the pointer as another type and increments it to match the alignment of diff --git 
a/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp b/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp index 8f5cdabdb..139ce7e8b 100644 --- a/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -39,8 +39,10 @@ namespace detail { template -struct default_adjacent_difference_config : default_adjacent_difference_config_base +struct default_adjacent_difference_config + : default_adjacent_difference_config_base::type {}; + // Based on value_type = double template struct default_adjacent_difference_config< diff --git a/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp b/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp index 0718b6e65..ab0df2f81 100644 --- a/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -40,7 +40,7 @@ namespace detail template struct default_adjacent_difference_inplace_config - : default_adjacent_difference_config_base + : default_adjacent_difference_config_base::type {}; // Based on value_type = double diff --git a/rocprim/include/rocprim/device/detail/config/device_histogram.hpp b/rocprim/include/rocprim/device/detail/config/device_histogram.hpp index e12d50069..cfd3dc72a 100644 --- a/rocprim/include/rocprim/device/detail/config/device_histogram.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_histogram.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -44,7 +44,7 @@ template struct default_histogram_config - : default_histogram_config_base + : default_histogram_config_base::type {}; // Based on value_type = double, channels = 1, active_channels = 1 diff --git a/rocprim/include/rocprim/device/detail/config/device_reduce.hpp b/rocprim/include/rocprim/device/detail/config/device_reduce.hpp index 54afc42f4..5def2822b 100644 --- a/rocprim/include/rocprim/device/detail/config/device_reduce.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_reduce.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -39,7 +39,7 @@ namespace detail { template -struct default_reduce_config : default_reduce_config_base +struct default_reduce_config : default_reduce_config_base::type {}; // Based on key_type = double diff --git a/rocprim/include/rocprim/device/detail/config/device_scan.hpp b/rocprim/include/rocprim/device/detail/config/device_scan.hpp index 7fe8e7259..c402b2178 100644 --- a/rocprim/include/rocprim/device/detail/config/device_scan.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_scan.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -39,7 +39,7 @@ namespace detail { template -struct default_scan_config : default_scan_config_base +struct default_scan_config : default_scan_config_base::type {}; // Based on value_type = double diff --git a/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp b/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp index f6f375f24..f5b10bce7 100644 --- a/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -39,7 +39,7 @@ namespace detail { template -struct default_scan_by_key_config : default_scan_by_key_config_base +struct default_scan_by_key_config : default_scan_by_key_config_base::type {}; // Based on key_type = double, value_type = int64_t diff --git a/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp b/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp new file mode 100644 index 000000000..51243cfb6 --- /dev/null +++ b/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp @@ -0,0 +1,4886 @@ +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEGMENTED_RADIX_SORT_HPP_ +#define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEGMENTED_RADIX_SORT_HPP_ + +#include "../../../type_traits.hpp" +#include "../device_config_helper.hpp" +#include + +/* DO NOT EDIT THIS FILE + * This file is automatically generated by `/scripts/autotune/create_optimization.py`. + * so most likely you want to edit rocprim/device/device_(algo)_config.hpp + */ + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template +struct default_segmented_radix_sort_config : default_segmented_radix_sort_config_base<7, 6>::type +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && 
(sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct 
default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config<6, + 4, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 256, 5, 8, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config<6, + 4, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 256, 5, 8, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct 
default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = 
int64_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 3, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 
256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + 
typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 5, 32, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 5, 32, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : 
segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 4, + 1, + typename std::conditional<1, + WarpSortConfig<8, 2, 256, 5, 16, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 4, + 1, + typename std::conditional<1, + WarpSortConfig<8, 2, 256, 5, 16, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config<6, + 4, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 256, 5, 8, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : 
segmented_radix_sort_config< + 6, + 4, + 256, + 4, + 1, + typename std::conditional<1, + WarpSortConfig<8, 2, 256, 5, 16, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx900), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : segmented_radix_sort_config<4, + 4, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 256, 5, 8, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && 
(sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 
4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config<7, + 2, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 256, 5, 8, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config<6, + 4, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 256, 5, 8, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 5, 32, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + 
static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = 
short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + 
DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename 
std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 4, + 1, + typename std::conditional<1, + WarpSortConfig<16, 2, 256, 5, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config<6, + 4, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 256, 5, 8, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 5, 32, 8, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 4, + 1, + typename std::conditional<1, + 
WarpSortConfig<8, 2, 256, 5, 16, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 
16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && 
(sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && 
(sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + 
static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = 
short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + 
DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename 
std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename 
std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : 
segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) 
&& (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 5, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + 
key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct 
default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + 
+// Based on key_type = int, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 
4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 
6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 6, + 4, 
+ 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> 
+ : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 
2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config<6, + 4, + 128, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 128, 5, 8, 8, 128>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + 
value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 6, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + 
static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type 
= short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + 
DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 2, + 256, + 16, + 1, + typename 
std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config<7, + 4, + 128, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 128, 5, 8, 8, 128>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config<6, + 4, + 128, + 8, + 1, + typename std::conditional<1, + WarpSortConfig<4, 4, 128, 5, 8, 8, 128>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 6, + 4, + 256, + 16, + 1, + typename std::conditional<1, + 
WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx1102), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 4, + 3, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, 
+ 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && 
(sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && 
(sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + 
static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type 
= short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + 
DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + 
typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + 
typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : 
segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && 
(sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + 
key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct 
default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// 
Based on key_type = int, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 4, + 256, + 16, + 1, + typename std::conditional<1, + WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 
256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 
256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 
17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 7, + 6, + 256, + 17, + 1, + typename std::conditional<1, + WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, + DisabledWarpSortConfig>::type> +{}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEGMENTED_RADIX_SORT_HPP_ \ No newline at end of file diff --git a/rocprim/include/rocprim/device/detail/device_config_helper.hpp b/rocprim/include/rocprim/device/detail/device_config_helper.hpp index 532b8d8cd..b6947170b 100644 --- a/rocprim/include/rocprim/device/detail/device_config_helper.hpp +++ b/rocprim/include/rocprim/device/detail/device_config_helper.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -255,7 +255,7 @@ namespace detail { template -struct default_reduce_config_base_helper +struct default_reduce_config_base { static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); @@ -265,10 +265,6 @@ struct default_reduce_config_base_helper ::rocprim::block_reduce_algorithm::using_warp_reduce>; }; -template -struct default_reduce_config_base : default_reduce_config_base_helper::type -{}; - struct scan_config_tag {}; @@ -337,7 +333,7 @@ struct scan_by_key_config_tag {}; template -struct default_scan_config_base_helper +struct default_scan_config_base { static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); @@ -349,10 +345,6 @@ struct default_scan_config_base_helper ::rocprim::block_scan_algorithm::using_warp_scan>; }; -template -struct default_scan_config_base : default_scan_config_base_helper::type -{}; - /// \brief Provides the kernel parameters for exclusive_scan_by_key and inclusive_scan_by_key based /// on autotuned configurations or user-provided configurations. 
struct scan_by_key_config_params @@ -415,7 +407,7 @@ namespace detail { template -struct default_scan_by_key_config_base_helper +struct default_scan_by_key_config_base { static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div( sizeof(Key) + sizeof(Value), 2 * sizeof(int)); @@ -428,10 +420,6 @@ struct default_scan_by_key_config_base_helper ::rocprim::block_scan_algorithm::using_warp_scan>; }; -template -struct default_scan_by_key_config_base : default_scan_by_key_config_base_helper::type -{}; - struct transform_config_tag {}; @@ -442,6 +430,203 @@ struct transform_config_params } // namespace detail +namespace detail +{ +struct segmented_radix_sort_config_tag +{}; + +struct warp_sort_config_params +{ + /// \brief Allow the partitioning of batches by size for processing via size-optimized kernels. + bool partitioning_allowed = false; + /// \brief The number of threads in the logical warp in the small segment processing kernel. + unsigned int logical_warp_size_small = 0; + /// \brief The number of items processed by a thread in the small segment processing kernel. + unsigned int items_per_thread_small = 0; + /// \brief The number of threads per block in the small segment processing kernel. + unsigned int block_size_small = 0; + /// \brief If the number of segments is at least \p partitioning_threshold, then the segments are partitioned into + /// small and large segment groups, and each group is handled by a different, specialized kernel. + unsigned int partitioning_threshold = 0; + /// \brief The number of threads in the logical warp in the medium segment processing kernel. + unsigned int logical_warp_size_medium = 0; + /// \brief The number of items processed by a thread in the medium segment processing kernel. + unsigned int items_per_thread_medium = 0; + /// \brief The number of threads per block in the medium segment processing kernel. 
+ unsigned int block_size_medium = 0; +}; + +struct segmented_radix_sort_config_params +{ + /// \brief Kernel start parameters. + kernel_config_params kernel_config{}; + /// \brief Number of bits in long iterations. + unsigned int long_radix_bits = 0; + /// \brief Number of bits in short iterations. + unsigned int short_radix_bits = 0; + /// \brief If set to \p true, warp sort can be used to sort the small segments, even if no partitioning happens. + bool enable_unpartitioned_warp_sort = true; + /// \brief Warp sort config params + warp_sort_config_params warp_sort_config{}; +}; + +} // namespace detail + +/// \brief Configuration of the warp sort part of the device segmented radix sort operation. +/// Short enough segments are processed on warp level. +/// +/// \tparam LogicalWarpSizeSmall - number of threads in the logical warp of the kernel +/// that processes small segments. +/// \tparam ItemsPerThreadSmall - number of items processed by a thread in the kernel that processes +/// small segments. +/// \tparam BlockSizeSmall - number of threads per block in the kernel which processes the small segments. +/// \tparam PartitioningThreshold - if the number of segments is at least this threshold, the +/// segments are partitioned to a small, a medium and a large segment collection. Both collections +/// are sorted by different kernels. Otherwise, all segments are sorted by a single kernel. +/// \tparam EnableUnpartitionedWarpSort - If set to \p true, warp sort can be used to sort +/// the small segments, even if the total number of segments is below \p PartitioningThreshold. +/// \tparam LogicalWarpSizeMedium - number of threads in the logical warp of the kernel +/// that processes medium segments. +/// \tparam ItemsPerThreadMedium - number of items processed by a thread in the kernel that processes +/// medium segments. +/// \tparam BlockSizeMedium - number of threads per block in the kernel which processes the medium segments. 
+template +struct WarpSortConfig +{ + static_assert(LogicalWarpSizeSmall * ItemsPerThreadSmall + <= LogicalWarpSizeMedium * ItemsPerThreadMedium, + "The number of items processed by a small warp cannot be larger than the number " + "of items processed by a medium warp"); + + /// \brief Allow the partitioning of batches by size for processing via size-optimized kernels. + static constexpr bool partitioning_allowed = true; + /// \brief The number of threads in the logical warp in the small segment processing kernel. + static constexpr unsigned int logical_warp_size_small = LogicalWarpSizeSmall; + /// \brief The number of items processed by a thread in the small segment processing kernel. + static constexpr unsigned int items_per_thread_small = ItemsPerThreadSmall; + /// \brief The number of threads per block in the small segment processing kernel. + static constexpr unsigned int block_size_small = BlockSizeSmall; + /// \brief If the number of segments is at least \p partitioning_threshold, then the segments are partitioned into + /// small and large segment groups, and each group is handled by a different, specialized kernel. + static constexpr unsigned int partitioning_threshold = PartitioningThreshold; + /// \brief The number of threads in the logical warp in the medium segment processing kernel. + static constexpr unsigned int logical_warp_size_medium = LogicalWarpSizeMedium; + /// \brief The number of items processed by a thread in the medium segment processing kernel. + static constexpr unsigned int items_per_thread_medium = ItemsPerThreadMedium; + /// \brief The number of threads per block in the medium segment processing kernel. + static constexpr unsigned int block_size_medium = BlockSizeMedium; +}; + +/// \brief Indicates if the warp level sorting is disabled in the +/// device segmented radix sort configuration. +struct DisabledWarpSortConfig +{ + /// \brief Allow the partitioning of batches by size for processing via size-optimized kernels. 
+ static constexpr bool partitioning_allowed = false; + /// \brief The number of threads in the logical warp in the small segment processing kernel. + static constexpr unsigned int logical_warp_size_small = 1; + /// \brief The number of items processed by a thread in the small segment processing kernel. + static constexpr unsigned int items_per_thread_small = 1; + /// \brief The number of threads per block in the small segment processing kernel. + static constexpr unsigned int block_size_small = 1; + /// \brief If the number of segments is at least \p partitioning_threshold, then the segments are partitioned into + /// small and large segment groups, and each group is handled by a different, specialized kernel. + static constexpr unsigned int partitioning_threshold = 0; + /// \brief The number of threads in the logical warp in the medium segment processing kernel. + static constexpr unsigned int logical_warp_size_medium = 1; + /// \brief The number of items processed by a thread in the medium segment processing kernel. + static constexpr unsigned int items_per_thread_medium = 1; + /// \brief The number of threads per block in the medium segment processing kernel. + static constexpr unsigned int block_size_medium = 1; +}; + +/// \brief Configuration for the device-level segmented radix sort operation. +/// \tparam LongRadixBits . +/// \tparam ShortRadixBits . +/// \tparam BlockSize Number of threads in a block. +/// \tparam ItemsPerThread Number of items processed by each thread. +/// \tparam SizeLimit Limit on the number of items for a single kernel launch. +template +struct segmented_radix_sort_config : public detail::segmented_radix_sort_config_params +{ + /// \brief Identifies the algorithm associated to the config. + /// \brief Identifies the algorithm associated to the config. + using tag = detail::segmented_radix_sort_config_tag; +#ifndef DOXYGEN_SHOULD_SKIP_THIS + + /// \brief Number of bits in long iterations. 
+ static constexpr unsigned int long_radix_bits = LongRadixBits; + + /// \brief Number of bits in short iterations. + static constexpr unsigned int short_radix_bits = ShortRadixBits; + + /// \brief Number of threads in a block. + static constexpr unsigned int block_size = BlockSize; + + /// \brief Number of items processed by each thread. + static constexpr unsigned int items_per_thread = ItemsPerThread; + + /// \brief If set to \p true, warp sort can be used to sort the small segments, even if no partitioning happens. + static constexpr bool enable_unpartitioned_warp_sort = EnableUnpartitionedWarpSort; + + /// \brief Limit on the number of items for a single kernel launch. + static constexpr unsigned int size_limit = SizeLimit; + + using warp_sort_config = WarpSortConfig; + + constexpr segmented_radix_sort_config() + : detail::segmented_radix_sort_config_params{ + {BlockSize, ItemsPerThread, SizeLimit}, + LongRadixBits, + ShortRadixBits, + EnableUnpartitionedWarpSort, + {warp_sort_config::partitioning_allowed, + warp_sort_config::logical_warp_size_small, + warp_sort_config::items_per_thread_small, + warp_sort_config::block_size_small, + warp_sort_config::partitioning_threshold, + warp_sort_config::logical_warp_size_medium, + warp_sort_config::items_per_thread_medium, + warp_sort_config::block_size_medium} + } + {} +#endif +}; + +namespace detail +{ +/// \brief Default segmented_radix_sort kernel configurations, such that the maximum shared memory is not exceeded. +/// +/// \tparam LongRadixBits - Long bits used during the sorting. +/// \tparam ShortRadixBits - Short bits used during the sorting. +/// \tparam ItemsPerThread - Items per thread when type Key has size 1. 
+template +struct default_segmented_radix_sort_config_base +{ + static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div( + sizeof(unsigned int) + sizeof(unsigned int), sizeof(int)); + using type = segmented_radix_sort_config>; +}; + +} // namespace detail + /// \brief Configuration for the device-level transform operation. /// \tparam BlockSize Number of threads in a block. /// \tparam ItemsPerThread Number of items processed by each thread. @@ -476,7 +661,7 @@ namespace detail { template -struct default_transform_config_base_helper +struct default_transform_config_base { static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); @@ -484,10 +669,6 @@ struct default_transform_config_base_helper using type = transform_config<256, ::rocprim::max(1u, 16u / item_scale)>; }; -template -struct default_transform_config_base : default_transform_config_base_helper::type -{}; - struct binary_search_config_tag : public transform_config_tag {}; struct upper_bound_config_tag : public transform_config_tag @@ -597,7 +778,7 @@ namespace detail { template -struct default_histogram_config_base_helper +struct default_histogram_config_base { static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Sample), sizeof(int)); @@ -606,11 +787,6 @@ struct default_histogram_config_base_helper = histogram_config>; }; -template -struct default_histogram_config_base - : default_histogram_config_base_helper::type -{}; - struct adjacent_difference_config_tag {}; @@ -657,7 +833,7 @@ namespace detail { template -struct default_adjacent_difference_config_base_helper +struct default_adjacent_difference_config_base { static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); @@ -669,11 +845,6 @@ struct default_adjacent_difference_config_base_helper ::rocprim::block_store_method::block_store_transpose>; }; -template -struct default_adjacent_difference_config_base - : 
default_adjacent_difference_config_base_helper::type -{}; - } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/device/detail/device_partition.hpp b/rocprim/include/rocprim/device/detail/device_partition.hpp index 7b49b2200..8fc63acc6 100644 --- a/rocprim/include/rocprim/device/detail/device_partition.hpp +++ b/rocprim/include/rocprim/device/detail/device_partition.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -688,6 +688,153 @@ ROCPRIM_DEVICE void load_selected_count(const size_t* const prev_selected_count, } } +template +class partition_values_helper +{ + +private: + ValueType values[ItemsPerThread]; + +public: + ROCPRIM_DEVICE void load(ValueIterator values_input, + unsigned int valid, + unsigned int is_global_last_block, + typename BlockLoadValueType::storage_type& load_values) + { + // Load values and sync threads + if(is_global_last_block) + { + BlockLoadValueType().load(values_input, values, valid, load_values); + } + else + { + BlockLoadValueType().load(values_input, values, load_values); + } + ::rocprim::syncthreads(); + } + + template + ROCPRIM_DEVICE void store(bool (&is_selected)[ItemsPerThread], + OffsetType (&output_indices)[ItemsPerThread], + OutputType values_output, + const size_t total_size, + const OffsetType selected_prefix, + const OffsetType selected_in_block, + ScatterStorageType& storage, + const unsigned int flat_block_id, + const unsigned int flat_block_thread_id, + const bool is_global_last_block, + const unsigned int valid_in_global_last_block, + size_t (&prev_selected_count_values)[1], + size_t prev_processed) + { + // Sync threads and store values + ::rocprim::syncthreads(); + partition_scatter(values, + 
is_selected, + output_indices, + values_output, + total_size, + selected_prefix, + selected_in_block, + storage, + flat_block_id, + flat_block_thread_id, + is_global_last_block, + valid_in_global_last_block, + prev_selected_count_values, + prev_processed); + } + + template + ROCPRIM_DEVICE void store(bool (&is_selected)[2][ItemsPerThread], + OffsetType (&output_indices)[ItemsPerThread], + OutputType values_output, + const size_t total_size, + const OffsetType selected_prefix, + const OffsetType selected_in_block, + ScatterStorageType& storage, + const unsigned int flat_block_id, + const unsigned int flat_block_thread_id, + const bool is_global_last_block, + const unsigned int valid_in_global_last_block, + size_t (&prev_selected_count_values)[2], + size_t prev_processed) + { + // Sync threads and store values + ::rocprim::syncthreads(); + partition_scatter(values, + is_selected, + output_indices, + values_output, + total_size, + selected_prefix, + selected_in_block, + storage, + flat_block_id, + flat_block_thread_id, + is_global_last_block, + valid_in_global_last_block, + prev_selected_count_values, + prev_processed); + } +}; + +template +class partition_values_helper +{ +public: + ROCPRIM_DEVICE void + load(ValueIterator, unsigned int, unsigned int, typename BlockLoadValueType::storage_type&) + {} + + template + ROCPRIM_DEVICE void store(bool[ItemsPerThread], + OffsetType[ItemsPerThread], + OutputType, + const size_t, + const OffsetType, + const OffsetType, + ScatterStorageType&, + const unsigned int, + const unsigned int, + const bool, + const unsigned int, + size_t[ItemsPerThread], + size_t) + {} + + template + ROCPRIM_DEVICE void store(bool[2][ItemsPerThread], + OffsetType[ItemsPerThread], + OutputType, + const size_t, + const OffsetType, + const OffsetType, + ScatterStorageType&, + const unsigned int, + const unsigned int, + const bool, + const unsigned int, + size_t[ItemsPerThread], + size_t) + {} +}; + template + values_helper; + 
values_helper.load(values_input + block_offset, + valid_in_global_last_block, + is_global_last_block, + storage.load_values); + // Load selection flags into is_selected, generate them using // input value and selection predicate, or generate them using // block_discontinuity primitive @@ -885,45 +1045,19 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void prev_selected_count_values, prev_processed); - static constexpr bool with_values = !std::is_same::value; - - if ROCPRIM_IF_CONSTEXPR (with_values) { - value_type values[items_per_thread]; - - ::rocprim::syncthreads(); // sync threads to reuse shared memory - if(is_global_last_block) - { - block_load_value_type().load(values_input + block_offset, - values, - valid_in_global_last_block, - storage.load_values); - } - else - { - block_load_value_type() - .load( - values_input + block_offset, - values, - storage.load_values - ); - } - ::rocprim::syncthreads(); // sync threads to reuse shared memory - - partition_scatter(values, - is_selected, - output_indices, - values_output, - total_size, - selected_prefix, - selected_in_block, - storage.exchange_values, - flat_block_id, - flat_block_thread_id, - is_global_last_block, - valid_in_global_last_block, - prev_selected_count_values, - prev_processed); - } + values_helper.store(is_selected, + output_indices, + values_output, + total_size, + selected_prefix, + selected_in_block, + storage.exchange_values, + flat_block_id, + flat_block_thread_id, + is_global_last_block, + valid_in_global_last_block, + prev_selected_count_values, + prev_processed); // Last block in grid stores number of selected values const bool is_last_block = flat_block_id == (number_of_blocks - 1); diff --git a/rocprim/include/rocprim/device/detail/device_reduce_by_key.hpp b/rocprim/include/rocprim/device/detail/device_reduce_by_key.hpp index 16bc5b866..4c0e6f21a 100644 --- a/rocprim/include/rocprim/device/detail/device_reduce_by_key.hpp +++ b/rocprim/include/rocprim/device/detail/device_reduce_by_key.hpp @@ -1,4 
+1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -28,7 +28,6 @@ #include "../../block/block_load.hpp" #include "../../block/block_scan.hpp" #include "../../block/block_store.hpp" -#include "../../detail/match_result_type.hpp" #include "../../detail/various.hpp" #include "../../intrinsics/thread.hpp" #include "../../thread/thread_operators.hpp" @@ -51,7 +50,7 @@ using value_type_t = typename std::iterator_traits::value_type; template using accumulator_type_t = - typename detail::match_result_type, BinaryOp>::type; + typename invoke_result_binary_op, BinaryOp>::type; template using wrapped_type_t = rocprim::tuple; diff --git a/rocprim/include/rocprim/device/detail/device_scan.hpp b/rocprim/include/rocprim/device/detail/device_scan.hpp index b11a1cfe2..90b31f7bf 100644 --- a/rocprim/include/rocprim/device/detail/device_scan.hpp +++ b/rocprim/include/rocprim/device/detail/device_scan.hpp @@ -93,23 +93,22 @@ template ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void lookback_scan_kernel_impl(InputIterator input, OutputIterator output, const size_t size, - ResultType initial_value, + AccType initial_value, BinaryFunction scan_op, LookbackScanState scan_state, const unsigned int number_of_blocks, - ResultType* previous_last_element = nullptr, - ResultType* new_last_element = nullptr, + AccType* previous_last_element = nullptr, + AccType* new_last_element = nullptr, bool override_first_value = false, bool save_last_value = false) { - using result_type = ResultType; - static_assert(std::is_same::value, + static_assert(std::is_same::value, "value_type of LookbackScanState must be result_type"); static constexpr scan_config_params params = device_params(); @@ -117,15 +116,14 @@ ROCPRIM_DEVICE 
ROCPRIM_FORCE_INLINE void constexpr auto items_per_thread = params.kernel_config.items_per_thread; constexpr unsigned int items_per_block = block_size * items_per_thread; - using block_load_type = ::rocprim:: - block_load; - using block_store_type = ::rocprim:: - block_store; - using block_scan_type - = ::rocprim::block_scan; + using block_load_type + = ::rocprim::block_load; + using block_store_type + = ::rocprim::block_store; + using block_scan_type = ::rocprim::block_scan; using lookback_scan_prefix_op_type - = lookback_scan_prefix_op; + = lookback_scan_prefix_op; ROCPRIM_SHARED_MEMORY union { @@ -140,7 +138,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void const auto valid_in_last_block = size - items_per_block * (number_of_blocks - 1); // For input values - result_type values[items_per_thread]; + AccType values[items_per_thread]; // load input values into values if(flat_block_id == (number_of_blocks - 1)) // last block @@ -165,12 +163,12 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void { if(Exclusive) initial_value - = scan_op(previous_last_element[0], static_cast(*(input - 1))); + = scan_op(previous_last_element[0], static_cast(*(input - 1))); else if(flat_block_thread_id == 0) values[0] = scan_op(previous_last_element[0], values[0]); } - result_type reduction; + AccType reduction; lookback_block_scan(values, // input/output initial_value, reduction, diff --git a/rocprim/include/rocprim/device/detail/device_scan_by_key.hpp b/rocprim/include/rocprim/device/detail/device_scan_by_key.hpp index fe481b0b2..1efb9c6fd 100644 --- a/rocprim/include/rocprim/device/detail/device_scan_by_key.hpp +++ b/rocprim/include/rocprim/device/detail/device_scan_by_key.hpp @@ -32,6 +32,7 @@ #include "../../detail/binary_op_wrappers.hpp" #include "../../intrinsics/thread.hpp" #include "../../types/tuple.hpp" +#include "rocprim/device/detail/device_config_helper.hpp" #include diff --git a/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp 
b/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp index b1a8cf33b..52fd28086 100644 --- a/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp +++ b/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -486,21 +486,14 @@ struct DisabledWarpSortHelperConfig static constexpr unsigned int block_size = 1; }; -template -using select_warp_sort_helper_config_small_t - = std::conditional_t::value, - DisabledWarpSortHelperConfig, - WarpSortHelperConfig>; - -template -using select_warp_sort_helper_config_medium_t - = std::conditional_t::value, - DisabledWarpSortHelperConfig, - WarpSortHelperConfig>; +template +using select_warp_sort_helper_config_t + = std::conditional_t, + DisabledWarpSortHelperConfig>; template< class Config, @@ -705,12 +698,14 @@ void segmented_sort(KeysInputIterator keys_input, unsigned int begin_bit, unsigned int end_bit) { - constexpr unsigned int long_radix_bits = Config::long_radix_bits; - constexpr unsigned int short_radix_bits = Config::short_radix_bits; - constexpr unsigned int block_size = Config::sort::block_size; - constexpr unsigned int items_per_thread = Config::sort::items_per_thread; - constexpr unsigned int items_per_block = block_size * items_per_thread; - constexpr bool warp_sort_enabled = Config::warp_sort_config::enable_unpartitioned_warp_sort; + static constexpr segmented_radix_sort_config_params params = device_params(); + + static constexpr unsigned int long_radix_bits = params.long_radix_bits; + static constexpr unsigned int short_radix_bits = params.short_radix_bits; + static constexpr unsigned int block_size = params.kernel_config.block_size; 
+ static constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread; + static constexpr unsigned int items_per_block = block_size * items_per_thread; + static constexpr bool warp_sort_enabled = params.enable_unpartitioned_warp_sort; using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; @@ -731,7 +726,10 @@ void segmented_sort(KeysInputIterator keys_input, short_radix_bits, Descending >; using warp_sort_helper_type = segmented_warp_sort_helper< - select_warp_sort_helper_config_small_t, + select_warp_sort_helper_config_t, key_type, value_type, Descending>; @@ -798,7 +796,7 @@ void segmented_sort(KeysInputIterator keys_input, storage.single_block_helper ); } - else if(::rocprim::flat_block_thread_id() < Config::warp_sort_config::logical_warp_size_small) + else if(::rocprim::flat_block_thread_id() < params.warp_sort_config.logical_warp_size_small) { // Single warp segment warp_sort_helper_type().sort( @@ -837,11 +835,13 @@ void segmented_sort_large(KeysInputIterator keys_input, unsigned int begin_bit, unsigned int end_bit) { - constexpr unsigned int long_radix_bits = Config::long_radix_bits; - constexpr unsigned int short_radix_bits = Config::short_radix_bits; - constexpr unsigned int block_size = Config::sort::block_size; - constexpr unsigned int items_per_thread = Config::sort::items_per_thread; - constexpr unsigned int items_per_block = block_size * items_per_thread; + static constexpr segmented_radix_sort_config_params params = device_params(); + + static constexpr unsigned int long_radix_bits = params.long_radix_bits; + static constexpr unsigned int short_radix_bits = params.short_radix_bits; + static constexpr unsigned int block_size = params.kernel_config.block_size; + static constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread; + static constexpr unsigned int items_per_block = block_size * items_per_thread; using key_type = typename 
std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; @@ -946,17 +946,101 @@ void segmented_sort_small(KeysInputIterator keys_input, unsigned int begin_bit, unsigned int end_bit) { - static constexpr unsigned int block_size = Config::block_size; - static constexpr unsigned int logical_warp_size = Config::logical_warp_size; - static_assert(block_size % logical_warp_size == 0, "logical_warp_size must be a divisor of block_size"); + static constexpr segmented_radix_sort_config_params params = device_params(); + + static constexpr unsigned int block_size = params.warp_sort_config.block_size_small; + static constexpr unsigned int logical_warp_size + = params.warp_sort_config.logical_warp_size_small; + static_assert(block_size % logical_warp_size == 0, + "logical_warp_size must be a divisor of block_size"); + static constexpr unsigned int warps_per_block = block_size / logical_warp_size; + + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + + using warp_sort_helper_type = segmented_warp_sort_helper< + select_warp_sort_helper_config_t, + key_type, + value_type, + Descending>; + + ROCPRIM_SHARED_MEMORY typename warp_sort_helper_type::storage_type storage; + + const unsigned int block_id = ::rocprim::detail::block_id<0>(); + const unsigned int logical_warp_id = ::rocprim::detail::logical_warp_id(); + const unsigned int segment_index = block_id * warps_per_block + logical_warp_id; + if(segment_index >= num_segments) + { + return; + } + + const unsigned int segment_id = segment_indices[segment_index]; + const unsigned int begin_offset = begin_offsets[segment_id]; + const unsigned int end_offset = end_offsets[segment_id]; + if(end_offset <= begin_offset) + { + return; + } + warp_sort_helper_type().sort(keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + to_output, + begin_offset, + end_offset, + begin_bit, + end_bit, + 
storage); +} + +template +ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void segmented_sort_medium( + KeysInputIterator keys_input, + typename std::iterator_traits::value_type* keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits::value_type* values_tmp, + ValuesOutputIterator values_output, + bool to_output, + unsigned int num_segments, + SegmentIndexIterator segment_indices, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int begin_bit, + unsigned int end_bit) +{ + static constexpr segmented_radix_sort_config_params params = device_params(); + + static constexpr unsigned int block_size = params.warp_sort_config.block_size_medium; + static constexpr unsigned int logical_warp_size + = params.warp_sort_config.logical_warp_size_medium; + static_assert(block_size % logical_warp_size == 0, + "logical_warp_size must be a divisor of block_size"); static constexpr unsigned int warps_per_block = block_size / logical_warp_size; using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; using warp_sort_helper_type = segmented_warp_sort_helper< - Config, key_type, value_type, Descending - >; + select_warp_sort_helper_config_t, + key_type, + value_type, + Descending>; ROCPRIM_SHARED_MEMORY typename warp_sort_helper_type::storage_type storage; diff --git a/rocprim/include/rocprim/device/detail/device_transform.hpp b/rocprim/include/rocprim/device/detail/device_transform.hpp index e84ee0731..1f966e264 100644 --- a/rocprim/include/rocprim/device/detail/device_transform.hpp +++ b/rocprim/include/rocprim/device/detail/device_transform.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -26,7 +26,6 @@ #include "../../config.hpp" #include "../../detail/various.hpp" -#include "../../detail/match_result_type.hpp" #include "../../intrinsics.hpp" #include "../../functional.hpp" @@ -45,7 +44,7 @@ namespace detail template struct unpack_binary_op { - using result_type = typename ::rocprim::detail::invoke_result::type; + using result_type = typename ::rocprim::invoke_result::type; ROCPRIM_HOST_DEVICE inline unpack_binary_op() = default; diff --git a/rocprim/include/rocprim/device/device_adjacent_difference.hpp b/rocprim/include/rocprim/device/device_adjacent_difference.hpp index 917bf4328..13476ce43 100644 --- a/rocprim/include/rocprim/device/device_adjacent_difference.hpp +++ b/rocprim/include/rocprim/device/device_adjacent_difference.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -154,9 +154,10 @@ hipError_t adjacent_difference_impl(void* const temporary_storage, // next block, otherwise the first item is needed for the previous block const auto offset = items_per_block - (Right ? 
0 : 1); - const auto block_starts_iter = make_transform_iterator( - rocprim::make_counting_iterator(std::size_t{0}), - [=, base = input + offset](std::size_t i) { return base[i * items_per_block]; }); + const auto block_starts_iter + = make_transform_iterator(rocprim::make_counting_iterator(std::size_t{0}), + [=, base = input + offset](std::size_t i) -> value_type + { return base[i * items_per_block]; }); const hipError_t error = ::rocprim::transform(block_starts_iter, previous_values, @@ -244,26 +245,26 @@ hipError_t adjacent_difference_impl(void* const temporary_storage, /// } /// \endcode /// -/// \tparam Config - [optional] configuration of the primitive. It has to be +/// \tparam Config [optional] configuration of the primitive. It has to be /// `adjacent_difference_config` or a class derived from it. -/// \tparam InputIt - [inferred] random-access iterator type of the input range. Must meet the +/// \tparam InputIt [inferred] random-access iterator type of the input range. Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam OutputIt - [inferred] random-access iterator type of the output range. Must meet the +/// \tparam OutputIt [inferred] random-access iterator type of the output range. Must meet the /// requirements of a C++ OutputIterator concept. It can be a simple pointer type. -/// \tparam BinaryFunction - [inferred] binary operation function object that will be applied to +/// \tparam BinaryFunction [inferred] binary operation function object that will be applied to /// consecutive items. The signature of the function should be equivalent to the following: /// `U f(const T1& a, const T2& b)`. The signature does not need to have /// `const &`, but function object must not modify the object passed to it -/// \param temporary_storage - pointer to a device-accessible temporary storage. When +/// \param temporary_storage pointer to a device-accessible temporary storage. 
When /// a null pointer is passed, the required allocation size (in bytes) is written to /// `storage_size` and function returns without performing the scan operation -/// \param storage_size - reference to a size (in bytes) of `temporary_storage` -/// \param input - iterator to the input range -/// \param output - iterator to the output range, must have any overlap with input -/// \param size - number of items in the input -/// \param op - [optional] the binary operation to apply -/// \param stream - [optional] HIP stream object. Default is `0` (the default stream) -/// \param debug_synchronous - [optional] If true, synchronization after every kernel +/// \param storage_size reference to a size (in bytes) of `temporary_storage` +/// \param input iterator to the input range +/// \param output iterator to the output range, must not have any overlap with input. +/// \param size number of items in the input +/// \param op [optional] the binary operation to apply +/// \param stream [optional] HIP stream object. Default is `0` (the default stream) +/// \param debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors and extra debugging info is printed to the /// standard output. Default value is `false` /// @@ -340,23 +341,23 @@ hipError_t adjacent_difference(void* const temporary_storage, /// } /// \endcode /// -/// \tparam Config - [optional] configuration of the primitive. It has to be +/// \tparam Config [optional] configuration of the primitive. It has to be /// `adjacent_difference_config` or a class derived from it. -/// \tparam InputIt - [inferred] random-access iterator type of the value range. Must meet the +/// \tparam InputIt [inferred] random-access iterator type of the value range. Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. 
-/// \tparam BinaryFunction - [inferred] binary operation function object that will be applied to +/// \tparam BinaryFunction [inferred] binary operation function object that will be applied to /// consecutive items. The signature of the function should be equivalent to the following: /// `U f(const T1& a, const T2& b)`. The signature does not need to have /// `const &`, but function object must not modify the object passed to it -/// \param temporary_storage - pointer to a device-accessible temporary storage. When +/// \param temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to /// `storage_size` and function returns without performing the scan operation -/// \param storage_size - reference to a size (in bytes) of `temporary_storage` -/// \param values - iterator to the range values, will be overwritten with the results -/// \param size - number of items in the input -/// \param op - [optional] the binary operation to apply -/// \param stream - [optional] HIP stream object. Default is `0` (the default stream) -/// \param debug_synchronous - [optional] If true, synchronization after every kernel +/// \param storage_size reference to a size (in bytes) of `temporary_storage` +/// \param values iterator to the range values, will be overwritten with the results +/// \param size number of items in the input +/// \param op [optional] the binary operation to apply +/// \param stream [optional] HIP stream object. Default is `0` (the default stream) +/// \param debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors and extra debugging info is printed to the /// standard output. 
Default value is `false` /// @@ -379,6 +380,65 @@ hipError_t adjacent_difference_inplace(void* const temporary_storage, temporary_storage, storage_size, values, values, size, op, stream, debug_synchronous); } +/// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements +/// in device accessible memory. Writes the output to the position of the left item. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be +/// `adjacent_difference_config` or a class derived from it. +/// \tparam InputIt [inferred] random-access iterator type of the value range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam OutputIt [inferred] random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam BinaryFunction [inferred] binary operation function object that will be applied to +/// consecutive items. The signature of the function should be equivalent to the following: +/// `U f(const T1& a, const T2& b)`. The signature does not need to have +/// `const &`, but function object must not modify the object passed to it +/// +/// \param temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// `storage_size` and function returns without performing the scan operation +/// \param storage_size reference to a size (in bytes) of `temporary_storage` +/// \param input iterator to the range values +/// \param output iterator to the output range. Allowed to point to the same elements as `input`. +/// Only complete overlap or no overlap at all is allowed between `input` and `output`. In other words +/// writing to `output[i]` is only allowed to overwrite `input[i]`, any other element must not be changed. 
+/// \param size number of items in the input +/// \param op [optional] the binary operation to apply +/// \param stream [optional] HIP stream object. Default is `0` (the default stream) +/// \param debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors and extra debugging info is printed to the +/// standard output. Default value is `false` +/// +/// \return `hipSuccess` (0) on success, otherwise the HIP runtime error of +/// type `hipError_t` +/// +/// \note This function has to perform an extra copy due to (potentially) writing its values in-place. If it is known that `input` and `output` +/// don't overlap then adjacent_difference should be preferred as it avoids this extra copy. +template> +hipError_t adjacent_difference_inplace(void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt output, + const std::size_t size, + const BinaryFunction op = BinaryFunction{}, + const hipStream_t stream = 0, + const bool debug_synchronous = false) +{ + static constexpr bool in_place = true; + static constexpr bool right = false; + return detail::adjacent_difference_impl(temporary_storage, + storage_size, + input, + output, + size, + op, + stream, + debug_synchronous); +} + /// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements /// in device accessible memory. Writes the output to the position of the right item. /// @@ -393,26 +453,26 @@ hipError_t adjacent_difference_inplace(void* const temporary_storage, /// } /// \endcode /// -/// \tparam Config - [optional] configuration of the primitive. It has to be +/// \tparam Config [optional] configuration of the primitive. It has to be /// `adjacent_difference_config` or a class derived from it. -/// \tparam InputIt - [inferred] random-access iterator type of the input range. Must meet the +/// \tparam InputIt [inferred] random-access iterator type of the input range. 
Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam OutputIt - [inferred] random-access iterator type of the output range. Must meet the +/// \tparam OutputIt [inferred] random-access iterator type of the output range. Must meet the /// requirements of a C++ OutputIterator concept. It can be a simple pointer type. -/// \tparam BinaryFunction - [inferred] binary operation function object that will be applied to +/// \tparam BinaryFunction [inferred] binary operation function object that will be applied to /// consecutive items. The signature of the function should be equivalent to the following: /// `U f(const T1& a, const T2& b)`. The signature does not need to have /// `const &`, but function object must not modify the object passed to it -/// \param temporary_storage - pointer to a device-accessible temporary storage. When +/// \param temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to /// `storage_size` and function returns without performing the scan operation -/// \param storage_size - reference to a size (in bytes) of `temporary_storage` -/// \param input - iterator to the input range -/// \param output - iterator to the output range, must have any overlap with input -/// \param size - number of items in the input -/// \param op - [optional] the binary operation to apply -/// \param stream - [optional] HIP stream object. Default is `0` (the default stream) -/// \param debug_synchronous - [optional] If true, synchronization after every kernel +/// \param storage_size reference to a size (in bytes) of `temporary_storage` +/// \param input iterator to the input range +/// \param output iterator to the output range, must not have any overlap with input. +/// \param size number of items in the input +/// \param op [optional] the binary operation to apply +/// \param stream [optional] HIP stream object. 
Default is `0` (the default stream) +/// \param debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors and extra debugging info is printed to the /// standard output. Default value is `false` /// @@ -489,23 +549,23 @@ hipError_t adjacent_difference_right(void* const temporary_storage, /// } /// \endcode /// -/// \tparam Config - [optional] configuration of the primitive. It has to be +/// \tparam Config [optional] configuration of the primitive. It has to be /// `adjacent_difference_config` or a class derived from it. -/// \tparam InputIt - [inferred] random-access iterator type of the value range. Must meet the +/// \tparam InputIt [inferred] random-access iterator type of the value range. Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam BinaryFunction - [inferred] binary operation function object that will be applied to +/// \tparam BinaryFunction [inferred] binary operation function object that will be applied to /// consecutive items. The signature of the function should be equivalent to the following: /// `U f(const T1& a, const T2& b)`. The signature does not need to have /// `const &`, but function object must not modify the object passed to it -/// \param temporary_storage - pointer to a device-accessible temporary storage. When +/// \param temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to /// `storage_size` and function returns without performing the scan operation -/// \param storage_size - reference to a size (in bytes) of `temporary_storage` -/// \param values - iterator to the range values, will be overwritten with the results -/// \param size - number of items in the input -/// \param op - [optional] the binary operation to apply -/// \param stream - [optional] HIP stream object. 
Default is `0` (the default stream) -/// \param debug_synchronous - [optional] If true, synchronization after every kernel +/// \param storage_size reference to a size (in bytes) of `temporary_storage` +/// \param values iterator to the range values, will be overwritten with the results +/// \param size number of items in the input +/// \param op [optional] the binary operation to apply +/// \param stream [optional] HIP stream object. Default is `0` (the default stream) +/// \param debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors and extra debugging info is printed to the /// standard output. Default value is `false` /// @@ -528,6 +588,64 @@ hipError_t adjacent_difference_right_inplace(void* const temporary_stor temporary_storage, storage_size, values, values, size, op, stream, debug_synchronous); } +/// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements +/// in device accessible memory. Writes the output to the position of the right item. +/// +/// \tparam Config [optional] configuration of the primitive. It has to be +/// `adjacent_difference_config` or a class derived from it. +/// \tparam InputIt [inferred] random-access iterator type of the value range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam OutputIt [inferred] random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam BinaryFunction [inferred] binary operation function object that will be applied to +/// consecutive items. The signature of the function should be equivalent to the following: +/// `U f(const T1& a, const T2& b)`. The signature does not need to have +/// `const &`, but function object must not modify the object passed to it +/// +/// \param temporary_storage pointer to a device-accessible temporary storage. 
When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// `storage_size` and function returns without performing the scan operation +/// \param storage_size reference to a size (in bytes) of `temporary_storage` +/// \param input iterator to the range values +/// \param output iterator to the output range. Allowed to point to the same elements as `input`. +/// Only complete overlap or no overlap at all is allowed between `input` and `output`. In other words +/// writing to `output[i]` is only allowed to overwrite `input[i]`, any other element must not be changed. +/// \param size number of items in the input +/// \param op [optional] the binary operation to apply +/// \param stream [optional] HIP stream object. Default is `0` (the default stream) +/// \param debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors and extra debugging info is printed to the +/// standard output. Default value is `false` +/// +/// \return `hipSuccess` (0) on success, otherwise the HIP runtime error of +/// type `hipError_t` +/// \note This function has to perform an extra copy due to (potentially) writing its values in-place. If it is known that `input` and `output` +/// don't overlap then adjacent_difference_right should be preferred as it avoids this extra copy. 
+template> +hipError_t adjacent_difference_right_inplace(void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt output, + const std::size_t size, + const BinaryFunction op = BinaryFunction{}, + const hipStream_t stream = 0, + const bool debug_synchronous = false) +{ + static constexpr bool in_place = true; + static constexpr bool right = true; + return detail::adjacent_difference_impl(temporary_storage, + storage_size, + input, + output, + size, + op, + stream, + debug_synchronous); +} + /// @} // end of group devicemodule diff --git a/rocprim/include/rocprim/device/device_adjacent_difference_config.hpp b/rocprim/include/rocprim/device/device_adjacent_difference_config.hpp index 0299484f1..e68391dd0 100644 --- a/rocprim/include/rocprim/device/device_adjacent_difference_config.hpp +++ b/rocprim/include/rocprim/device/device_adjacent_difference_config.hpp @@ -83,6 +83,11 @@ struct wrapped_adjacent_difference_config }; #ifndef DOXYGEN_SHOULD_SKIP_THIS +template +template +constexpr adjacent_difference_config_params + wrapped_adjacent_difference_config:: + architecture_config::params; template template constexpr adjacent_difference_config_params diff --git a/rocprim/include/rocprim/device/device_reduce.hpp b/rocprim/include/rocprim/device/device_reduce.hpp index 649e1c583..f72339b15 100644 --- a/rocprim/include/rocprim/device/device_reduce.hpp +++ b/rocprim/include/rocprim/device/device_reduce.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -29,7 +29,6 @@ #include "config_types.hpp" #include "../config.hpp" -#include "../detail/match_result_type.hpp" #include "../detail/temp_storage.hpp" #include "../detail/various.hpp" @@ -112,9 +111,8 @@ hipError_t reduce_impl(void * temporary_storage, bool debug_synchronous) { using input_type = typename std::iterator_traits::value_type; - using result_type = typename ::rocprim::detail::match_result_type< - input_type, BinaryFunction - >::type; + using result_type = + typename ::rocprim::invoke_result_binary_op::type; using config = wrapped_reduce_config; diff --git a/rocprim/include/rocprim/device/device_reduce_by_key.hpp b/rocprim/include/rocprim/device/device_reduce_by_key.hpp index 315f75668..db63d50d5 100644 --- a/rocprim/include/rocprim/device/device_reduce_by_key.hpp +++ b/rocprim/include/rocprim/device/device_reduce_by_key.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -30,7 +30,6 @@ #include "detail/lookback_scan_state.hpp" #include "../config.hpp" -#include "../detail/match_result_type.hpp" #include "../detail/temp_storage.hpp" #include "../detail/various.hpp" #include "../functional.hpp" diff --git a/rocprim/include/rocprim/device/device_scan.hpp b/rocprim/include/rocprim/device/device_scan.hpp index 2cb648dad..305bd3416 100644 --- a/rocprim/include/rocprim/device/device_scan.hpp +++ b/rocprim/include/rocprim/device/device_scan.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -51,10 +51,10 @@ template + class AccType> ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void single_scan_kernel_impl(InputIterator input, const size_t input_size, - ResultType initial_value, + AccType initial_value, OutputIterator output, BinaryFunction scan_op) { @@ -63,14 +63,11 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void single_scan_kernel_impl(InputIterator constexpr unsigned int block_size = params.kernel_config.block_size; constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread; - using result_type = ResultType; - - using block_load_type = ::rocprim:: - block_load; - using block_store_type = ::rocprim:: - block_store; - using block_scan_type - = ::rocprim::block_scan; + using block_load_type + = ::rocprim::block_load; + using block_store_type + = ::rocprim::block_store; + using block_scan_type = ::rocprim::block_scan; ROCPRIM_SHARED_MEMORY union { @@ -79,7 +76,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void single_scan_kernel_impl(InputIterator typename block_scan_type::storage_type scan; } storage; - result_type values[items_per_thread]; + AccType values[items_per_thread]; // load input values into values block_load_type().load(input, values, input_size, *(input), storage.load); ::rocprim::syncthreads(); // sync threads to reuse shared memory @@ -100,7 +97,8 @@ template + class InitValueType, + class AccType> ROCPRIM_KERNEL __launch_bounds__(device_params().kernel_config.block_size) void single_scan_kernel( InputIterator input, @@ -111,7 +109,7 @@ ROCPRIM_KERNEL { single_scan_kernel_impl(input, size, - get_input_value(initial_value), + static_cast(get_input_value(initial_value)), output, scan_op); } @@ -124,32 +122,34 @@ template ROCPRIM_KERNEL __launch_bounds__(device_params().kernel_config.block_size) void lookback_scan_kernel( - InputIterator input, - OutputIterator output, - 
const size_t size, - const InitValueType initial_value, - BinaryFunction scan_op, - LookBackScanState lookback_scan_state, - const unsigned int number_of_blocks, - input_type_t* previous_last_element = nullptr, - input_type_t* new_last_element = nullptr, - bool override_first_value = false, - bool save_last_value = false) + InputIterator input, + OutputIterator output, + const size_t size, + const InitValueType initial_value, + BinaryFunction scan_op, + LookBackScanState lookback_scan_state, + const unsigned int number_of_blocks, + AccType* previous_last_element = nullptr, + AccType* new_last_element = nullptr, + bool override_first_value = false, + bool save_last_value = false) { - lookback_scan_kernel_impl(input, - output, - size, - get_input_value(initial_value), - scan_op, - lookback_scan_state, - number_of_blocks, - previous_last_element, - new_last_element, - override_first_value, - save_last_value); + lookback_scan_kernel_impl( + input, + output, + size, + static_cast(get_input_value(initial_value)), + scan_op, + lookback_scan_state, + number_of_blocks, + previous_last_element, + new_last_element, + override_first_value, + save_last_value); } #define ROCPRIM_DETAIL_HIP_SYNC(name, size, start) \ @@ -183,7 +183,8 @@ template + class BinaryFunction, + class AccType> inline auto scan_impl(void* temporary_storage, size_t& storage_size, InputIterator input, @@ -194,9 +195,7 @@ inline auto scan_impl(void* temporary_storage, const hipStream_t stream, bool debug_synchronous) { - using real_init_value_type = input_type_t; - - using config = wrapped_scan_config; + using config = wrapped_scan_config; detail::target_arch target_arch; hipError_t result = host_target_arch(stream, target_arch); @@ -206,8 +205,8 @@ inline auto scan_impl(void* temporary_storage, } const scan_config_params params = dispatch_target_arch(target_arch); - using scan_state_type = detail::lookback_scan_state; - using scan_state_with_sleep_type = detail::lookback_scan_state; + using scan_state_type = 
detail::lookback_scan_state; + using scan_state_with_sleep_type = detail::lookback_scan_state; const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -222,9 +221,9 @@ inline auto scan_impl(void* temporary_storage, unsigned int number_of_blocks = (limited_size + items_per_block - 1)/items_per_block; // Pointer to array with block_prefixes - void* scan_state_storage; - real_init_value_type* previous_last_element; - real_init_value_type* new_last_element; + void* scan_state_storage; + AccType* previous_last_element; + AccType* new_last_element; detail::temp_storage::layout layout{}; hipError_t layout_result @@ -304,24 +303,14 @@ inline auto scan_impl(void* temporary_storage, if(std::string(prop.gcnArchName).find("908") != std::string::npos && asicRevision < 2) { - hipLaunchKernelGGL( - HIP_KERNEL_NAME(init_lookback_scan_state_kernel), - dim3(grid_size), - dim3(block_size), - 0, - stream, - scan_state_with_sleep, - number_of_blocks); + init_lookback_scan_state_kernel + <<>>(scan_state_with_sleep, + number_of_blocks); } else { - hipLaunchKernelGGL( - HIP_KERNEL_NAME(init_lookback_scan_state_kernel), - dim3(grid_size), - dim3(block_size), - 0, - stream, - scan_state, - number_of_blocks); + init_lookback_scan_state_kernel + <<>>(scan_state, + number_of_blocks); } ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("init_lookback_scan_state_kernel", number_of_blocks, start) @@ -329,30 +318,25 @@ inline auto scan_impl(void* temporary_storage, grid_size = number_of_blocks; if(std::string(prop.gcnArchName).find("908") != std::string::npos && asicRevision < 2) { - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - lookback_scan_kernel), - dim3(grid_size), - dim3(block_size), - 0, - stream, - input + offset, - output + offset, - current_size, - initial_value, - scan_op, - scan_state_with_sleep, - number_of_blocks, - previous_last_element, - new_last_element, - i != size_t(0), - number_of_launch > 1); + 
lookback_scan_kernel + <<>>(input + offset, + output + offset, + current_size, + initial_value, + scan_op, + scan_state_with_sleep, + number_of_blocks, + previous_last_element, + new_last_element, + i != size_t(0), + number_of_launch > 1); } else { @@ -366,30 +350,25 @@ inline auto scan_impl(void* temporary_storage, std::cout << "items_per_block " << items_per_block << '\n'; } - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - lookback_scan_kernel), - dim3(grid_size), - dim3(block_size), - 0, - stream, - input + offset, - output + offset, - current_size, - initial_value, - scan_op, - scan_state, - number_of_blocks, - previous_last_element, - new_last_element, - i != size_t(0), - number_of_launch > 1); + lookback_scan_kernel + <<>>(input + offset, + output + offset, + current_size, + initial_value, + scan_op, + scan_state, + number_of_blocks, + previous_last_element, + new_last_element, + i != size_t(0), + number_of_launch > 1); } ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("lookback_scan_kernel", current_size, start) @@ -399,7 +378,7 @@ inline auto scan_impl(void* temporary_storage, hipError_t error = ::rocprim::transform(new_last_element, previous_last_element, 1, - ::rocprim::identity(), + ::rocprim::identity(), stream, debug_synchronous); if(error != hipSuccess) return error; @@ -414,24 +393,17 @@ inline auto scan_impl(void* temporary_storage, std::cout << "block_size " << block_size << '\n'; std::cout << "number of blocks " << number_of_blocks << '\n'; std::cout << "items_per_block " << items_per_block << '\n'; + start = std::chrono::high_resolution_clock::now(); } - if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - hipLaunchKernelGGL( - HIP_KERNEL_NAME(single_scan_kernel), - dim3(1), - dim3(block_size), - 0, - stream, - input, - size, - initial_value, - output, - scan_op); + single_scan_kernel + <<>>(input, size, initial_value, output, scan_op); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("single_scan_kernel", size, start); } return 
hipSuccess; @@ -455,7 +427,7 @@ inline auto scan_impl(void* temporary_storage, /// if \p temporary_storage in a null pointer. /// * Ranges specified by \p input and \p output must have at least \p size elements. /// * By default, the input type is used for accumulation. A custom type -/// can be specified using rocprim::transform_iterator, see the example below. +/// can be specified using the \p AccType type parameter, see the example below. /// /// \tparam Config - [optional] configuration of the primitive, has to be \p scan_config or a class derived from it. /// \tparam InputIterator - random-access iterator type of the input range. Must meet the @@ -464,6 +436,8 @@ inline auto scan_impl(void* temporary_storage, /// requirements of a C++ OutputIterator concept. It can be a simple pointer type. /// \tparam BinaryFunction - type of binary function used for scan. Default type /// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. +/// \tparam AccType - accumulator type used to propagate the scanned values. Default type +/// is value type of the input iterator. /// /// \param [in] temporary_storage - pointer to a device-accessible temporary storage. 
When /// a null pointer is passed, the required allocation size (in bytes) is written to @@ -526,31 +500,35 @@ inline auto scan_impl(void* temporary_storage, /// short * input; /// int * output; /// -/// // Use a transform iterator to specify a custom accumulator type -/// auto input_iterator = rocprim::make_transform_iterator( -/// input, [] __device__ (T in) { return static_cast(in); }); -/// /// size_t temporary_storage_size_bytes; /// void * temporary_storage_ptr = nullptr; -/// // Use the transform iterator +/// /// rocprim::inclusive_scan( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// input_iterator, output, input_size, rocprim::plus() +/// input, output, input_size, rocprim::plus() /// ); /// /// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); /// -/// rocprim::inclusive_scan( -/// temporary_storage_ptr, temporary_storage_size_bytes, -/// input_iterator, output, input_size, rocprim::plus() -/// ); +/// // Use type parameter to set custom accumulator type +/// rocprim::inclusive_scan, +/// int>(temporary_storage_ptr, +/// temporary_storage_size_bytes, +/// input, +/// output, +/// input_size, +/// rocprim::plus()); /// \endcode /// \endparblock template::value_type>> + = ::rocprim::plus::value_type>, + class AccType = typename std::iterator_traits::value_type> inline hipError_t inclusive_scan(void* temporary_storage, size_t& storage_size, InputIterator input, @@ -560,17 +538,18 @@ inline hipError_t inclusive_scan(void* temporary_storage, const hipStream_t stream = 0, bool debug_synchronous = false) { - using input_type = typename std::iterator_traits::value_type; // input_type() is a dummy initial value (not used) - return detail::scan_impl(temporary_storage, - storage_size, - input, - output, - input_type(), - size, - scan_op, - stream, - debug_synchronous); + return detail:: + scan_impl( + temporary_storage, + storage_size, + input, + output, + AccType{}, + size, + scan_op, + stream, + debug_synchronous); } /// 
\brief Parallel exclusive scan primitive for device level. @@ -594,6 +573,8 @@ inline hipError_t inclusive_scan(void* temporary_storage, /// \tparam InitValueType - type of the initial value. /// \tparam BinaryFunction - type of binary function used for scan. Default type /// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. +/// \tparam AccType - accumulator type used to propagate the scanned values. Default type +/// is 'InitValueType', unless it's 'rocprim::future_value'. Then it will be the wrapped input type. /// /// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to @@ -654,7 +635,7 @@ inline hipError_t inclusive_scan(void* temporary_storage, /// temporary_storage_ptr, temporary_storage_size_bytes, /// input, output, start_value, input_size, min_op /// ); -/// // output: [9, 4, 7, 6, 2, 2, 1, 1] +/// // output: [9, 4, 4, 4, 2, 2, 1, 1] /// \endcode /// \endparblock template::value_type>> + = ::rocprim::plus::value_type>, + class AccType = detail::input_type_t> inline hipError_t exclusive_scan(void* temporary_storage, size_t& storage_size, InputIterator input, @@ -673,15 +655,21 @@ inline hipError_t exclusive_scan(void* temporary_storage, const hipStream_t stream = 0, bool debug_synchronous = false) { - return detail::scan_impl(temporary_storage, - storage_size, - input, - output, - initial_value, - size, - scan_op, - stream, - debug_synchronous); + return detail::scan_impl(temporary_storage, + storage_size, + input, + output, + initial_value, + size, + scan_op, + stream, + debug_synchronous); } /// @} diff --git a/rocprim/include/rocprim/device/device_scan_by_key.hpp b/rocprim/include/rocprim/device/device_scan_by_key.hpp index 8d8c4fecb..73d5fcb0b 100644 --- a/rocprim/include/rocprim/device/device_scan_by_key.hpp +++ b/rocprim/include/rocprim/device/device_scan_by_key.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 
Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -54,31 +54,32 @@ template + typename AccType> void __global__ __launch_bounds__(device_params().kernel_config.block_size) - device_scan_by_key_kernel(const KeyInputIterator keys, - const InputIterator values, - const OutputIterator output, - const InitialValueType initial_value, - const CompareFunction compare, - const BinaryFunction scan_op, - const LookbackScanState scan_state, - const size_t size, - const size_t starting_block, - const size_t number_of_blocks, - const ::rocprim::tuple* const previous_last_value) + device_scan_by_key_kernel(const KeyInputIterator keys, + const InputIterator values, + const OutputIterator output, + const InitialValueType initial_value, + const CompareFunction compare, + const BinaryFunction scan_op, + const LookbackScanState scan_state, + const size_t size, + const size_t starting_block, + const size_t number_of_blocks, + const ::rocprim::tuple* const previous_last_value) { - device_scan_by_key_kernel_impl(keys, - values, - output, - get_input_value(initial_value), - compare, - scan_op, - scan_state, - size, - starting_block, - number_of_blocks, - previous_last_value); + device_scan_by_key_kernel_impl( + keys, + values, + output, + static_cast(get_input_value(initial_value)), + compare, + scan_op, + scan_state, + size, + starting_block, + number_of_blocks, + previous_last_value); } #define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ @@ -106,7 +107,8 @@ template + typename CompareFunction, + typename AccType> inline hipError_t scan_by_key_impl(void* const temporary_storage, size_t& storage_size, KeysInputIterator keys, @@ -119,10 +121,9 @@ inline hipError_t scan_by_key_impl(void* const temporary_storage, const hipStream_t 
stream, const bool debug_synchronous) { - using key_type = typename std::iterator_traits::value_type; - using real_init_value_type = input_type_t; + using key_type = typename std::iterator_traits::value_type; - using config = wrapped_scan_by_key_config; + using config = wrapped_scan_by_key_config; detail::target_arch target_arch; hipError_t result = host_target_arch(stream, target_arch); @@ -132,7 +133,7 @@ inline hipError_t scan_by_key_impl(void* const temporary_storage, } const scan_by_key_config_params params = dispatch_target_arch(target_arch); - using wrapped_type = ::rocprim::tuple; + using wrapped_type = ::rocprim::tuple; using scan_state_type = detail::lookback_scan_state; using scan_state_with_sleep_type = detail::lookback_scan_state; @@ -284,7 +285,7 @@ inline hipError_t scan_by_key_impl(void* const temporary_storage, keys + offset, input + offset, output + offset, - static_cast(initial_value), + initial_value, compare, scan_op, scan_state, @@ -331,6 +332,8 @@ inline hipError_t scan_by_key_impl(void* const temporary_storage, /// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. /// \tparam KeyCompareFunction - type of binary function used to determine keys equality. Default type /// is \p rocprim::equal_to, where \p T is a \p value_type of \p KeysInputIterator. +/// \tparam AccType - accumulator type used to propagate the scanned values. Default type +/// is value type of the input iterator. /// /// \param [in] temporary_storage - pointer to a device-accessible temporary storage. 
When /// a null pointer is passed, the required allocation size (in bytes) is written to @@ -402,7 +405,8 @@ template::value_type>, typename KeyCompareFunction - = ::rocprim::equal_to::value_type>> + = ::rocprim::equal_to::value_type>, + typename AccType = typename std::iterator_traits::value_type> inline hipError_t inclusive_scan_by_key(void* const temporary_storage, size_t& storage_size, const KeysInputIterator keys_input, @@ -416,17 +420,25 @@ inline hipError_t inclusive_scan_by_key(void* const temporary_sto const bool debug_synchronous = false) { using value_type = typename std::iterator_traits::value_type; - return detail::scan_by_key_impl(temporary_storage, - storage_size, - keys_input, - values_input, - values_output, - value_type(), - size, - scan_op, - key_compare_op, - stream, - debug_synchronous); + return detail::scan_by_key_impl(temporary_storage, + storage_size, + keys_input, + values_input, + values_output, + value_type(), + size, + scan_op, + key_compare_op, + stream, + debug_synchronous); } /// \brief Parallel exclusive scan-by-key primitive for device level. @@ -455,6 +467,8 @@ inline hipError_t inclusive_scan_by_key(void* const temporary_sto /// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. /// \tparam KeyCompareFunction - type of binary function used to determine keys equality. Default type /// is \p rocprim::equal_to, where \p T is a \p value_type of \p KeysInputIterator. +/// \tparam AccType - accumulator type used to propagate the scanned values. Default type +/// is 'InitValueType', unless it's 'rocprim::future_value'. Then it will be the wrapped input type. /// /// \param [in] temporary_storage - pointer to a device-accessible temporary storage. 
When /// a null pointer is passed, the required allocation size (in bytes) is written to @@ -530,7 +544,8 @@ template::value_type>, typename KeyCompareFunction - = ::rocprim::equal_to::value_type>> + = ::rocprim::equal_to::value_type>, + typename AccType = detail::input_type_t> inline hipError_t exclusive_scan_by_key(void* const temporary_storage, size_t& storage_size, const KeysInputIterator keys_input, @@ -544,17 +559,25 @@ inline hipError_t exclusive_scan_by_key(void* const temporary_sto const hipStream_t stream = 0, const bool debug_synchronous = false) { - return detail::scan_by_key_impl(temporary_storage, - storage_size, - keys_input, - values_input, - values_output, - initial_value, - size, - scan_op, - key_compare_op, - stream, - debug_synchronous); + return detail::scan_by_key_impl(temporary_storage, + storage_size, + keys_input, + values_input, + values_output, + initial_value, + size, + scan_op, + key_compare_op, + stream, + debug_synchronous); } /// @} diff --git a/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp b/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp index 576789f06..58089b810 100644 --- a/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp +++ b/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp @@ -25,10 +25,12 @@ #include #include #include +#include #include "../config.hpp" -#include "../detail/various.hpp" #include "../detail/radix_sort.hpp" +#include "../detail/various.hpp" +#include "config_types.hpp" #include "../intrinsics.hpp" #include "../functional.hpp" @@ -49,31 +51,28 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class Config, - bool Descending, - unsigned int BlockSize, - class KeysInputIterator, - class KeysOutputIterator, - class ValuesInputIterator, - class ValuesOutputIterator, - class OffsetIterator -> +template ROCPRIM_KERNEL -__launch_bounds__(BlockSize) -void segmented_sort_kernel(KeysInputIterator keys_input, - typename std::iterator_traits::value_type * 
keys_tmp, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, - typename std::iterator_traits::value_type * values_tmp, - ValuesOutputIterator values_output, - bool to_output, - OffsetIterator begin_offsets, - OffsetIterator end_offsets, - unsigned int long_iterations, - unsigned int short_iterations, - unsigned int begin_bit, - unsigned int end_bit) + __launch_bounds__(device_params().kernel_config.block_size) void segmented_sort_kernel( + KeysInputIterator keys_input, + typename std::iterator_traits::value_type* keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits::value_type* values_tmp, + ValuesOutputIterator values_output, + bool to_output, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int long_iterations, + unsigned int short_iterations, + unsigned int begin_bit, + unsigned int end_bit) { segmented_sort( keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, @@ -84,33 +83,34 @@ void segmented_sort_kernel(KeysInputIterator keys_input, ); } -template< - class Config, - bool Descending, - unsigned int BlockSize, - class KeysInputIterator, - class KeysOutputIterator, - class ValuesInputIterator, - class ValuesOutputIterator, - class SegmentIndexIterator, - class OffsetIterator -> -ROCPRIM_KERNEL -__launch_bounds__(BlockSize) -void segmented_sort_large_kernel(KeysInputIterator keys_input, - typename std::iterator_traits::value_type * keys_tmp, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, - typename std::iterator_traits::value_type * values_tmp, - ValuesOutputIterator values_output, - bool to_output, - SegmentIndexIterator segment_indices, - OffsetIterator begin_offsets, - OffsetIterator end_offsets, - unsigned int long_iterations, - unsigned int short_iterations, - unsigned int begin_bit, - unsigned int end_bit) +template +ROCPRIM_KERNEL __launch_bounds__( + device_params() + .kernel_config + .block_size) void 
segmented_sort_large_kernel(KeysInputIterator keys_input, + typename std::iterator_traits< + KeysInputIterator>::value_type* keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits< + ValuesInputIterator>::value_type* + values_tmp, + ValuesOutputIterator values_output, + bool to_output, + SegmentIndexIterator segment_indices, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int long_iterations, + unsigned int short_iterations, + unsigned int begin_bit, + unsigned int end_bit) { segmented_sort_large( keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, @@ -122,28 +122,33 @@ void segmented_sort_large_kernel(KeysInputIterator keys_input, } template -ROCPRIM_KERNEL __launch_bounds__(BlockSize) void segmented_sort_small_or_medium_kernel( - KeysInputIterator keys_input, - typename std::iterator_traits::value_type* keys_tmp, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, - typename std::iterator_traits::value_type* values_tmp, - ValuesOutputIterator values_output, - bool to_output, - unsigned int num_segments, - SegmentIndexIterator segment_indices, - OffsetIterator begin_offsets, - OffsetIterator end_offsets, - unsigned int begin_bit, - unsigned int end_bit) +ROCPRIM_KERNEL __launch_bounds__( + device_params() + .warp_sort_config + .block_size_small) void segmented_sort_small_kernel(KeysInputIterator keys_input, + typename std::iterator_traits< + KeysInputIterator>::value_type* + keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits< + ValuesInputIterator>::value_type* + values_tmp, + ValuesOutputIterator values_output, + bool to_output, + unsigned int num_segments, + SegmentIndexIterator segment_indices, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int begin_bit, + unsigned int end_bit) { segmented_sort_small( keys_input, keys_tmp, keys_output, values_input, 
values_tmp, values_output, @@ -153,6 +158,50 @@ ROCPRIM_KERNEL __launch_bounds__(BlockSize) void segmented_sort_small_or_medium_ ); } +template +ROCPRIM_KERNEL __launch_bounds__( + device_params() + .warp_sort_config + .block_size_medium) void segmented_sort_medium_kernel(KeysInputIterator keys_input, + typename std::iterator_traits< + KeysInputIterator>::value_type* + keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits< + ValuesInputIterator>::value_type* + values_tmp, + ValuesOutputIterator values_output, + bool to_output, + unsigned int num_segments, + SegmentIndexIterator segment_indices, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int begin_bit, + unsigned int end_bit) +{ + segmented_sort_medium(keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + to_output, + num_segments, + segment_indices, + begin_offsets, + end_offsets, + begin_bit, + end_bit); +} + #define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ { \ auto _error = hipGetLastError(); \ @@ -168,42 +217,12 @@ ROCPRIM_KERNEL __launch_bounds__(BlockSize) void segmented_sort_small_or_medium_ } \ } -struct TwoWayPartitioner +struct Partitioner { - template - hipError_t operator()(void* temporary_storage, - size_t& storage_size, - InputIterator input, - FirstOutputIterator output_first_part, - SecondOutputIterator /*output_second_part*/, - UnselectedOutputIterator /*output_unselected*/, - SelectedCountOutputIterator selected_count_output, - const size_t size, - FirstUnaryPredicate select_first_part_op, - SecondUnaryPredicate /*select_second_part_op*/, - const hipStream_t stream, - const bool debug_synchronous) - { - return partition(temporary_storage, - storage_size, - input, - output_first_part, - selected_count_output, - size, - select_first_part_op, - stream, - debug_synchronous); - } -}; + bool three_way_partitioning; + + Partitioner(bool three_way_part) : 
three_way_partitioning(three_way_part) {} -struct ThreeWayPartitioner -{ template - >; + using config = wrapped_segmented_radix_sort_config; + + detail::target_arch target_arch; + hipError_t result = detail::host_target_arch(stream, target_arch); + if(result != hipSuccess) + { + return result; + } + + const detail::segmented_radix_sort_config_params params + = detail::dispatch_target_arch(target_arch); static constexpr bool with_values = !std::is_same::value; - static constexpr bool partitioning_allowed = - !std::is_same::value; - static constexpr unsigned int max_small_segment_length - = config::warp_sort_config::items_per_thread_small - * config::warp_sort_config::logical_warp_size_small; - static constexpr unsigned int small_segments_per_block - = config::warp_sort_config::block_size_small - / config::warp_sort_config::logical_warp_size_small; - static constexpr unsigned int max_medium_segment_length - = config::warp_sort_config::items_per_thread_medium - * config::warp_sort_config::logical_warp_size_medium; - static constexpr unsigned int medium_segments_per_block - = config::warp_sort_config::block_size_medium - / config::warp_sort_config::logical_warp_size_medium; - static_assert( - max_small_segment_length <= max_medium_segment_length, - "The max length of small segments cannot be higher than the max length of medium segments"); - // Don't waste cycles on 3-way partitioning, if the small and medium segments are equal length - static constexpr bool three_way_partitioning - = max_small_segment_length < max_medium_segment_length; - using partitioner_type - = std::conditional_t; - partitioner_type partitioner; + const bool partitioning_allowed = params.warp_sort_config.partitioning_allowed; + const unsigned int max_small_segment_length = params.warp_sort_config.items_per_thread_small + * params.warp_sort_config.logical_warp_size_small; + const unsigned int small_segments_per_block = params.warp_sort_config.block_size_small + / 
params.warp_sort_config.logical_warp_size_small; + const unsigned int max_medium_segment_length + = params.warp_sort_config.items_per_thread_medium + * params.warp_sort_config.logical_warp_size_medium; + const unsigned int medium_segments_per_block + = params.warp_sort_config.block_size_medium + / params.warp_sort_config.logical_warp_size_medium; + + const bool three_way_partitioning = max_small_segment_length < max_medium_segment_length; + Partitioner partitioner(three_way_partitioning); const auto large_segment_selector = [=](const unsigned int segment_index) mutable -> bool { @@ -325,20 +357,22 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, const bool with_double_buffer = keys_tmp != nullptr; const unsigned int bits = end_bit - begin_bit; - const unsigned int iterations = ::rocprim::detail::ceiling_div(bits, config::long_radix_bits); + const unsigned int iterations = ::rocprim::detail::ceiling_div(bits, params.long_radix_bits); const bool to_output = with_double_buffer || (iterations - 1) % 2 == 0; is_result_in_output = (iterations % 2 == 0) != to_output; - const unsigned int radix_bits_diff = config::long_radix_bits - config::short_radix_bits; - const unsigned int short_iterations = radix_bits_diff != 0 - ? ::rocprim::min(iterations, (config::long_radix_bits * iterations - bits) / radix_bits_diff) - : 0; + const unsigned int radix_bits_diff = params.long_radix_bits - params.short_radix_bits; + const unsigned int short_iterations + = radix_bits_diff != 0 + ? ::rocprim::min(iterations, + (params.long_radix_bits * iterations - bits) / radix_bits_diff) + : 0; const unsigned int long_iterations = iterations - short_iterations; - const bool do_partitioning = partitioning_allowed - && segments >= config::warp_sort_config::partitioning_threshold; + const bool do_partitioning + = partitioning_allowed && segments >= params.warp_sort_config.partitioning_threshold; - const size_t medium_segment_indices_size = three_way_partitioning ? 
segments : 0; - static constexpr size_t segment_count_output_size = three_way_partitioning ? 2 : 1; - const size_t segment_count_output_bytes + const size_t medium_segment_indices_size = three_way_partitioning ? segments : 0; + const size_t segment_count_output_size = three_way_partitioning ? 2 : 1; + const size_t segment_count_output_bytes = segment_count_output_size * sizeof(segment_index_type); segment_index_type* large_segment_indices_output{}; @@ -412,10 +446,14 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, std::cout << "long_iterations " << long_iterations << '\n'; std::cout << "short_iterations " << short_iterations << '\n'; std::cout << "do_partitioning " << do_partitioning << '\n'; - std::cout << "config::sort::block_size: " << config::sort::block_size << '\n'; - std::cout << "config::sort::items_per_thread: " << config::sort::items_per_thread << '\n'; + std::cout << "params.kernel_config.block_size: " << params.kernel_config.block_size << '\n'; + std::cout << "params.kernel_config.items_per_thread: " + << params.kernel_config.items_per_thread << '\n'; hipError_t error = hipStreamSynchronize(stream); - if(error != hipSuccess) return error; + if(error != hipSuccess) + { + return error; + } } if(!with_double_buffer) @@ -443,8 +481,9 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, { return result; } - segment_index_type segment_counts[segment_count_output_size]{}; - result = detail::memcpy_and_sync(&segment_counts, + std::vector segment_counts(segment_count_output_size, + segment_index_type{}); + result = detail::memcpy_and_sync(segment_counts.data(), segment_count_output, segment_count_output_bytes, hipMemcpyDeviceToHost, @@ -466,15 +505,25 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, { std::chrono::high_resolution_clock::time_point start; if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - hipLaunchKernelGGL( - HIP_KERNEL_NAME(segmented_sort_large_kernel), - 
dim3(large_segment_count), dim3(config::sort::block_size), 0, stream, - keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, - to_output, large_segment_indices_output, - begin_offsets, end_offsets, - long_iterations, short_iterations, - begin_bit, end_bit - ); + hipLaunchKernelGGL(HIP_KERNEL_NAME(segmented_sort_large_kernel), + dim3(large_segment_count), + dim3(params.kernel_config.block_size), + 0, + stream, + keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + to_output, + large_segment_indices_output, + begin_offsets, + end_offsets, + long_iterations, + short_iterations, + begin_bit, + end_bit); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:large_segments", large_segment_count, start) @@ -486,29 +535,24 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, std::chrono::high_resolution_clock::time_point start; if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - segmented_sort_small_or_medium_kernel< - select_warp_sort_helper_config_medium_t, - Descending, - config::warp_sort_config::block_size_medium>), - dim3(medium_segment_grid_size), - dim3(config::warp_sort_config::block_size_medium), - 0, - stream, - keys_input, - keys_tmp, - keys_output, - values_input, - values_tmp, - values_output, - is_result_in_output, - medium_segment_count, - medium_segment_indices_output, - begin_offsets, - end_offsets, - begin_bit, - end_bit); + hipLaunchKernelGGL(HIP_KERNEL_NAME(segmented_sort_medium_kernel), + dim3(medium_segment_grid_size), + dim3(params.warp_sort_config.block_size_medium), + 0, + stream, + keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + is_result_in_output, + medium_segment_count, + medium_segment_indices_output, + begin_offsets, + end_offsets, + begin_bit, + end_bit); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:medium_segments", medium_segment_count, start) @@ 
-519,29 +563,24 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, small_segments_per_block); std::chrono::high_resolution_clock::time_point start; if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - segmented_sort_small_or_medium_kernel< - select_warp_sort_helper_config_small_t, - Descending, - config::warp_sort_config::block_size_small>), - dim3(small_segment_grid_size), - dim3(config::warp_sort_config::block_size_small), - 0, - stream, - keys_input, - keys_tmp, - keys_output, - values_input, - values_tmp, - values_output, - is_result_in_output, - small_segment_count, - small_segment_indices_output, - begin_offsets, - end_offsets, - begin_bit, - end_bit); + hipLaunchKernelGGL(HIP_KERNEL_NAME(segmented_sort_small_kernel), + dim3(small_segment_grid_size), + dim3(params.warp_sort_config.block_size_small), + 0, + stream, + keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + is_result_in_output, + small_segment_count, + small_segment_indices_output, + begin_offsets, + end_offsets, + begin_bit, + end_bit); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:small_segments", small_segment_count, start) @@ -551,15 +590,24 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, { std::chrono::high_resolution_clock::time_point start; if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - hipLaunchKernelGGL( - HIP_KERNEL_NAME(segmented_sort_kernel), - dim3(segments), dim3(config::sort::block_size), 0, stream, - keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, - to_output, - begin_offsets, end_offsets, - long_iterations, short_iterations, - begin_bit, end_bit - ); + hipLaunchKernelGGL(HIP_KERNEL_NAME(segmented_sort_kernel), + dim3(segments), + dim3(params.kernel_config.block_size), + 0, + stream, + keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + to_output, + 
begin_offsets, + end_offsets, + long_iterations, + short_iterations, + begin_bit, + end_bit); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort", segments, start) } return hipSuccess; diff --git a/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp b/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp index 7d5b248a8..62def5dff 100644 --- a/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp +++ b/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2018-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -29,88 +29,14 @@ #include "../functional.hpp" #include "config_types.hpp" +#include "detail/config/device_segmented_radix_sort.hpp" +#include "detail/device_config_helper.hpp" /// \addtogroup primitivesmodule_deviceconfigs /// @{ BEGIN_ROCPRIM_NAMESPACE -/// \brief Configuration of the warp sort part of the device segmented radix sort operation. -/// Short enough segments are processed on warp level. -/// -/// \tparam LogicalWarpSizeSmall - number of threads in the logical warp of the kernel -/// that processes small segments. -/// \tparam ItemsPerThreadSmall - number of items processed by a thread in the kernel that processes -/// small segments. -/// \tparam BlockSizeSmall - number of threads per block in the kernel which processes the small segments. -/// \tparam PartitioningThreshold - if the number of segments is at least this threshold, the -/// segments are partitioned to a small, a medium and a large segment collection. Both collections -/// are sorted by different kernels. Otherwise, all segments are sorted by a single kernel. 
-/// \tparam EnableUnpartitionedWarpSort - If set to \p true, warp sort can be used to sort -/// the small segments, even if the total number of segments is below \p PartitioningThreshold. -/// \tparam LogicalWarpSizeMedium - number of threads in the logical warp of the kernel -/// that processes medium segments. -/// \tparam ItemsPerThreadMedium - number of items processed by a thread in the kernel that processes -/// medium segments. -/// \tparam BlockSizeMedium - number of threads per block in the kernel which processes the medium segments. -template -struct WarpSortConfig -{ - static_assert(LogicalWarpSizeSmall * ItemsPerThreadSmall - <= LogicalWarpSizeMedium * ItemsPerThreadMedium, - "The number of items processed by a small warp cannot be larger than the number " - "of items processed by a medium warp"); - /// \brief The number of threads in the logical warp in the small segment processing kernel. - static constexpr unsigned int logical_warp_size_small = LogicalWarpSizeSmall; - /// \brief The number of items processed by a thread in the small segment processing kernel. - static constexpr unsigned int items_per_thread_small = ItemsPerThreadSmall; - /// \brief The number of threads per block in the small segment processing kernel. - static constexpr unsigned int block_size_small = BlockSizeSmall; - /// \brief If the number of segments is at least \p partitioning_threshold, then the segments are partitioned into - /// small and large segment groups, and each group is handled by a different, specialized kernel. - static constexpr unsigned int partitioning_threshold = PartitioningThreshold; - /// \brief If set to \p true, warp sort can be used to sort the small segments, even if the total number of - /// segments is below \p PartitioningThreshold. - static constexpr bool enable_unpartitioned_warp_sort = EnableUnpartitionedWarpSort; - /// \brief The number of threads in the logical warp in the medium segment processing kernel. 
- static constexpr unsigned int logical_warp_size_medium = LogicalWarpSizeMedium; - /// \brief The number of items processed by a thread in the medium segment processing kernel. - static constexpr unsigned int items_per_thread_medium = ItemsPerThreadMedium; - /// \brief The number of threads per block in the medium segment processing kernel. - static constexpr unsigned int block_size_medium = BlockSizeMedium; -}; - -/// \brief Indicates if the warp level sorting is disabled in the -/// device segmented radix sort configuration. -struct DisabledWarpSortConfig -{ - /// \brief The number of threads in the logical warp in the small segment processing kernel. - static constexpr unsigned int logical_warp_size_small = 1; - /// \brief The number of items processed by a thread in the small segment processing kernel. - static constexpr unsigned int items_per_thread_small = 1; - /// \brief The number of threads per block in the small segment processing kernel. - static constexpr unsigned int block_size_small = 1; - /// \brief If the number of segments is at least \p partitioning_threshold, then the segments are partitioned into - /// small and large segment groups, and each group is handled by a different, specialized kernel. - static constexpr unsigned int partitioning_threshold = 0; - /// \brief If set to \p true, warp sort can be used to sort the small segments, even if the total number of - /// segments is below \p PartitioningThreshold. - static constexpr bool enable_unpartitioned_warp_sort = false; - /// \brief The number of threads in the logical warp in the medium segment processing kernel. - static constexpr unsigned int logical_warp_size_medium = 1; - /// \brief The number of items processed by a thread in the medium segment processing kernel. - static constexpr unsigned int items_per_thread_medium = 1; - /// \brief The number of threads per block in the medium segment processing kernel. 
- static constexpr unsigned int block_size_medium = 1; -}; - /// \brief Selects the appropriate \p WarpSortConfig based on the size of the key type. /// /// \tparam Key - the type of the sorted keys. @@ -123,235 +49,54 @@ using select_warp_sort_config_t 4, //< items per thread - small kernel 256, //< block size - small kernel 3000, //< partitioning threshold - (sizeof(Key) > 2), //< enable unpartitioned warp sort MediumWarpSize, //< logical warp size - medium kernel 4, //< items per thread - medium kernel 256 //< block size - medium kernel >>; -/// \brief Configuration of device-level segmented radix sort operation. -/// -/// Radix sort is excecuted in a few iterations (passes) depending on total number of bits to be sorted -/// (\p begin_bit and \p end_bit), each iteration sorts either \p LongRadixBits or \p ShortRadixBits bits -/// choosen to cover whole bit range in optimal way. -/// -/// For example, if \p LongRadixBits is 7, \p ShortRadixBits is 6, \p begin_bit is 0 and \p end_bit is 32 -/// there will be 5 iterations: 7 + 7 + 6 + 6 + 6 = 32 bits. -/// -/// If a segment's element count is low ( <= warp_sort_config::items_per_thread * warp_sort_config::logical_warp_size ), -/// it is sorted by a special warp-level sorting method. -/// -/// \tparam LongRadixBits - number of bits in long iterations. -/// \tparam ShortRadixBits - number of bits in short iterations, must be equal to or less than \p LongRadixBits. -/// \tparam SortConfig - configuration of radix sort kernel. Must be \p kernel_config. -/// \tparam WarpSortConfig - configuration of the warp sort that is used on the short segments. -template< - unsigned int LongRadixBits, - unsigned int ShortRadixBits, - class SortConfig, - class WarpSortConfig = DisabledWarpSortConfig -> -struct segmented_radix_sort_config -{ - /// \brief Number of bits in long iterations. 
- static constexpr unsigned int long_radix_bits = LongRadixBits; - /// \brief Number of bits in short iterations - static constexpr unsigned int short_radix_bits = ShortRadixBits; - /// \brief Configuration of radix sort kernel. - using sort = SortConfig; - /// \brief Configuration of the warp sort method. - using warp_sort_config = WarpSortConfig; -}; - namespace detail { -template -struct segmented_radix_sort_config_803 +template +struct wrapped_segmented_radix_sort_config { - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); - - using type = select_type< - select_type_case< - (sizeof(Key) == 1 && sizeof(Value) <= 8), - segmented_radix_sort_config<8, 7, kernel_config<256, 10>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 2 && sizeof(Value) <= 8), - segmented_radix_sort_config<8, 7, kernel_config<256, 10>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 4 && sizeof(Value) <= 8), - segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 8 && sizeof(Value) <= 8), - segmented_radix_sort_config<7, 6, kernel_config<256, 13>, select_warp_sort_config_t > - >, - segmented_radix_sort_config<7, 6, kernel_config<256, ::rocprim::max(1u, 15u / item_scale)>, select_warp_sort_config_t > - >; -}; - -template -struct segmented_radix_sort_config_803 - : select_type< - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > > - > { }; - -template -struct segmented_radix_sort_config_900 -{ - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); - - using type = select_type< - select_type_case< - (sizeof(Key) == 1 && sizeof(Value) <= 8), - 
segmented_radix_sort_config<4, 4, kernel_config<256, 10>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 2 && sizeof(Value) <= 8), - segmented_radix_sort_config<6, 5, kernel_config<256, 10>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 4 && sizeof(Value) <= 8), - segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 8 && sizeof(Value) <= 8), - segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t > - >, - segmented_radix_sort_config<7, 6, kernel_config<256, ::rocprim::max(1u, 15u / item_scale)>, select_warp_sort_config_t > - >; -}; - -template -struct segmented_radix_sort_config_900 - : select_type< - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > > - > { }; - -template -struct segmented_radix_sort_config_90a -{ - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); - - using type = select_type< - select_type_case< - (sizeof(Key) == 1 && sizeof(Value) <= 8), - segmented_radix_sort_config<4, - 4, - kernel_config<256, 10>, - select_warp_sort_config_t>>, - select_type_case< - (sizeof(Key) == 2 && sizeof(Value) <= 8), - segmented_radix_sort_config<6, - 5, - kernel_config<256, 10>, - select_warp_sort_config_t>>, - select_type_case< - (sizeof(Key) == 4 && sizeof(Value) <= 8), - segmented_radix_sort_config<7, - 6, - kernel_config<256, 15>, - select_warp_sort_config_t>>, - select_type_case< - (sizeof(Key) == 8 && sizeof(Value) <= 8), - segmented_radix_sort_config<7, - 6, - kernel_config<256, 15>, - select_warp_sort_config_t>>, - segmented_radix_sort_config<7, - 6, - kernel_config<256, ::rocprim::max(1u, 15u / item_scale)>, - select_warp_sort_config_t>>; + static_assert(std::is_same::value, + 
"Config must be a specialization of struct template segmented_radix_sort_config"); + + template + struct architecture_config + { + static constexpr detail::segmented_radix_sort_config_params params + = SegmentedRadixSortConfig{}; + }; }; -template -struct segmented_radix_sort_config_90a - : select_type< - select_type_case< - sizeof(Key) == 1, - segmented_radix_sort_config<4, - 3, - kernel_config<256, 10>, - select_warp_sort_config_t>>, - select_type_case< - sizeof(Key) == 2, - segmented_radix_sort_config<6, - 5, - kernel_config<256, 10>, - select_warp_sort_config_t>>, - select_type_case< - sizeof(Key) == 4, - segmented_radix_sort_config<7, - 6, - kernel_config<256, 17>, - select_warp_sort_config_t>>, - select_type_case< - sizeof(Key) == 8, - segmented_radix_sort_config<7, - 6, - kernel_config<256, 15>, - select_warp_sort_config_t>>> -{}; - -template -struct segmented_radix_sort_config_1030 +template +struct wrapped_segmented_radix_sort_config { - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); - - using type = select_type< - select_type_case< - (sizeof(Key) == 1 && sizeof(Value) <= 8), - segmented_radix_sort_config<4, 4, kernel_config<256, 10>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 2 && sizeof(Value) <= 8), - segmented_radix_sort_config<6, 5, kernel_config<256, 10>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 4 && sizeof(Value) <= 8), - segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t > - >, - select_type_case< - (sizeof(Key) == 8 && sizeof(Value) <= 8), - segmented_radix_sort_config<7, 6, kernel_config<256, 15>, select_warp_sort_config_t > - >, - segmented_radix_sort_config<7, 6, kernel_config<256, ::rocprim::max(1u, 15u / item_scale)>, select_warp_sort_config_t > - >; + template + struct architecture_config + { + static constexpr segmented_radix_sort_config_params params + = 
detail::default_segmented_radix_sort_config(Arch), + key_type, + value_type>{}; + }; }; -template -struct segmented_radix_sort_config_1030 - : select_type< - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > >, - select_type_case, select_warp_sort_config_t > > - > { }; - -template -struct default_segmented_radix_sort_config - : select_arch< - TargetArch, - select_arch_case<803, detail::segmented_radix_sort_config_803>, - select_arch_case<900, detail::segmented_radix_sort_config_900>, - select_arch_case<906, detail::segmented_radix_sort_config_90a>, - select_arch_case<908, detail::segmented_radix_sort_config_90a>, - select_arch_case>, - select_arch_case<1030, detail::segmented_radix_sort_config_1030>, - detail::segmented_radix_sort_config_900> -{}; +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template +template +constexpr segmented_radix_sort_config_params + wrapped_segmented_radix_sort_config:: + architecture_config::params; +template +template +constexpr segmented_radix_sort_config_params + wrapped_segmented_radix_sort_config:: + architecture_config::params; +#endif // DOXYGEN_SHOULD_SKIP_THIS } // end namespace detail diff --git a/rocprim/include/rocprim/device/device_segmented_reduce.hpp b/rocprim/include/rocprim/device/device_segmented_reduce.hpp index 424b291ec..d3e9a76a7 100644 --- a/rocprim/include/rocprim/device/device_segmented_reduce.hpp +++ b/rocprim/include/rocprim/device/device_segmented_reduce.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -26,12 +26,12 @@ #include #include "../config.hpp" -#include "../functional.hpp" #include "../detail/various.hpp" -#include "../detail/match_result_type.hpp" +#include "../functional.hpp" #include "detail/config/device_reduce.hpp" #include "detail/device_segmented_reduce.hpp" +#include "rocprim/type_traits.hpp" BEGIN_ROCPRIM_NAMESPACE @@ -100,9 +100,8 @@ hipError_t segmented_reduce_impl(void * temporary_storage, bool debug_synchronous) { using input_type = typename std::iterator_traits::value_type; - using result_type = typename ::rocprim::detail::match_result_type< - input_type, BinaryFunction - >::type; + using result_type = + typename ::rocprim::invoke_result_binary_op::type; using config = wrapped_reduce_config; diff --git a/rocprim/include/rocprim/device/device_segmented_scan.hpp b/rocprim/include/rocprim/device/device_segmented_scan.hpp index 4cad6e0f4..e54564fbb 100644 --- a/rocprim/include/rocprim/device/device_segmented_scan.hpp +++ b/rocprim/include/rocprim/device/device_segmented_scan.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -27,7 +27,6 @@ #include "../config.hpp" #include "../detail/various.hpp" -#include "../detail/match_result_type.hpp" #include "../iterator/zip_iterator.hpp" #include "../iterator/discard_iterator.hpp" diff --git a/rocprim/include/rocprim/device/device_transform.hpp b/rocprim/include/rocprim/device/device_transform.hpp index a9de7d827..4238290a1 100644 --- a/rocprim/include/rocprim/device/device_transform.hpp +++ b/rocprim/include/rocprim/device/device_transform.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -28,9 +28,8 @@ #include "../config.hpp" #include "../detail/various.hpp" -#include "../detail/match_result_type.hpp" -#include "../types/tuple.hpp" #include "../iterator/zip_iterator.hpp" +#include "../types/tuple.hpp" #include "device_transform_config.hpp" #include "detail/device_transform.hpp" @@ -142,7 +141,7 @@ inline hipError_t transform(InputIterator input, return hipSuccess; using input_type = typename std::iterator_traits::value_type; - using result_type = typename ::rocprim::detail::invoke_result::type; + using result_type = typename ::rocprim::invoke_result::type; using config = detail::wrapped_transform_config; diff --git a/rocprim/include/rocprim/device/device_transform_config.hpp b/rocprim/include/rocprim/device/device_transform_config.hpp index bf6c6c1a2..f350fe48a 100644 --- a/rocprim/include/rocprim/device/device_transform_config.hpp +++ b/rocprim/include/rocprim/device/device_transform_config.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2018-2023 Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -39,7 +39,7 @@ namespace detail // device transform does not have config tuning template -struct default_transform_config : default_transform_config_base +struct default_transform_config : default_transform_config_base::type {}; template diff --git a/rocprim/include/rocprim/intrinsics/warp.hpp b/rocprim/include/rocprim/intrinsics/warp.hpp index ba192af87..9c1d1eded 100644 --- a/rocprim/include/rocprim/intrinsics/warp.hpp +++ b/rocprim/include/rocprim/intrinsics/warp.hpp @@ -116,9 +116,6 @@ int warp_all(int predicate) } // end detail namespace -/// @} -// end of group intrinsicsmodule - /// \brief Group active lanes having the same bits of \p label /// /// Threads that have the same least significant \p LabelBits bits are grouped into the same group. @@ -185,6 +182,9 @@ ROCPRIM_DEVICE ROCPRIM_INLINE bool group_elect(lane_mask_type mask) return prev_same_count == 0 && mask != 0; } +/// @} +// end of group intrinsicsmodule + END_ROCPRIM_NAMESPACE #endif // ROCPRIM_INTRINSICS_WARP_HPP_ diff --git a/rocprim/include/rocprim/iterator/transform_iterator.hpp b/rocprim/include/rocprim/iterator/transform_iterator.hpp index db3d9a000..e98b8f03d 100644 --- a/rocprim/include/rocprim/iterator/transform_iterator.hpp +++ b/rocprim/include/rocprim/iterator/transform_iterator.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -26,7 +26,7 @@ #include #include "../config.hpp" -#include "../detail/match_result_type.hpp" +#include "../type_traits.hpp" /// \addtogroup iteratormodule /// @{ @@ -47,14 +47,11 @@ BEGIN_ROCPRIM_NAMESPACE /// \tparam UnaryFunction - type of the transform functor. /// \tparam ValueType - type of value that can be obtained by dereferencing the iterator. /// By default it is the return type of \p UnaryFunction. -template< - class InputIterator, - class UnaryFunction, - class ValueType = - typename ::rocprim::detail::invoke_result< - UnaryFunction, typename std::iterator_traits::value_type - >::type -> +template::value_type>::type> class transform_iterator { public: diff --git a/rocprim/include/rocprim/thread/thread_operators.hpp b/rocprim/include/rocprim/thread/thread_operators.hpp index 46f3d72a4..59ae5ac7b 100644 --- a/rocprim/include/rocprim/thread/thread_operators.hpp +++ b/rocprim/include/rocprim/thread/thread_operators.hpp @@ -1,7 +1,7 @@ /****************************************************************************** * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. - * Modifications Copyright (c) 2017-2021, Advanced Micro Devices, Inc. All rights reserved. + * Modifications Copyright (c) 2017-2024, Advanced Micro Devices, Inc. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: @@ -175,7 +175,7 @@ namespace detail // rocPRIM (as well as Thrust) uses result type of BinaryFunction instead (if not void): // // using input_type = typename std::iterator_traits::value_type; -// using result_type = typename ::rocprim::detail::match_result_type< +// using result_type = typename ::rocprim::match_result_type< // input_type, BinaryFunction // >::type; // diff --git a/rocprim/include/rocprim/type_traits.hpp b/rocprim/include/rocprim/type_traits.hpp index c616decfe..8a40d79f3 100644 --- a/rocprim/include/rocprim/type_traits.hpp +++ b/rocprim/include/rocprim/type_traits.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -21,12 +21,11 @@ #ifndef ROCPRIM_TYPE_TRAITS_HPP_ #define ROCPRIM_TYPE_TRAITS_HPP_ -#include - -// Meta configuration for rocPRIM #include "config.hpp" #include "types.hpp" +#include + /// \addtogroup utilsmodule_typetraits /// @{ @@ -197,6 +196,117 @@ auto TwiddleOut(UnsignedBits key) }; #endif // DOXYGEN_SHOULD_SKIP_THIS +namespace detail +{ + +// invoke_result is based on std::invoke_result. +// The main difference is using ROCPRIM_HOST_DEVICE, this allows to +// use invoke_result with device-only lambdas/functors in host-only functions +// on HIP-clang. + +template +struct is_reference_wrapper : std::false_type +{}; +template +struct is_reference_wrapper> : std::true_type +{}; + +template +struct invoke_impl +{ + template + ROCPRIM_HOST_DEVICE static auto call(F&& f, Args&&... 
args) + -> decltype(std::forward(f)(std::forward(args)...)); +}; + +template +struct invoke_impl +{ + template::type, + class = typename std::enable_if::value>::type> + ROCPRIM_HOST_DEVICE static auto get(T&& t) -> T&&; + + template::type, + class = typename std::enable_if::value>::type> + ROCPRIM_HOST_DEVICE static auto get(T&& t) -> decltype(t.get()); + + template::type, + class = typename std::enable_if::value>::type, + class = typename std::enable_if::value>::type> + ROCPRIM_HOST_DEVICE static auto get(T&& t) -> decltype(*std::forward(t)); + + template::value>::type> + ROCPRIM_HOST_DEVICE static auto call(MT1 B::*pmf, T&& t, Args&&... args) + -> decltype((invoke_impl::get(std::forward(t)).*pmf)(std::forward(args)...)); + + template + ROCPRIM_HOST_DEVICE static auto call(MT B::*pmd, T&& t) + -> decltype(invoke_impl::get(std::forward(t)).*pmd); +}; + +template::type> +ROCPRIM_HOST_DEVICE auto INVOKE(F&& f, Args&&... args) + -> decltype(invoke_impl::call(std::forward(f), std::forward(args)...)); + +// Conforming C++14 implementation (is also a valid C++11 implementation): +template +struct invoke_result_impl +{}; +template +struct invoke_result_impl(), std::declval()...))), + F, + Args...> +{ + using type = decltype(INVOKE(std::declval(), std::declval()...)); +}; + +} // end namespace detail + +/// \brief Behaves like ``std::invoke_result``, but allows the use of invoke_result +/// with device-only lambdas/functors in host-only functions on HIP-clang. +/// +/// \tparam F Type of the function. +/// \tparam Args Input type(s) to the function ``F``. +template +struct invoke_result : detail::invoke_result_impl +{ +#ifdef DOXYGEN_DOCUMENTATION_BUILD + /// \brief The return type of the Callable type F if invoked with the arguments Args. + /// \hideinitializer + using type = detail::invoke_result_impl::type; +#endif // DOXYGEN_DOCUMENTATION_BUILD +}; + +/// \brief Helper type. It is an alias for ``invoke_result::type``. +/// +/// \tparam F Type of the function. 
+/// \tparam Args Input type(s) to the function ``F``. +template +using invoke_result_t = typename invoke_result::type; + +/// \brief Utility wrapper around ``invoke_result`` for binary operators. +/// +/// \tparam T Input type to the binary operator. +/// \tparam F Type of the binary operator. +template +struct invoke_result_binary_op +{ + /// \brief The result type of the binary operator. + using type = typename invoke_result::type; +}; + +/// \brief Helper type. It is an alias for ``invoke_result_binary_op::type``. +/// +/// \tparam T Input type to the binary operator. +/// \tparam F Type of the binary operator. +template +using invoke_result_binary_op_t = typename invoke_result_binary_op::type; END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp b/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp index 9ce2350ba..f53458818 100644 --- a/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2018-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -155,9 +155,19 @@ class warp_scan_dpp } template - ROCPRIM_DEVICE ROCPRIM_INLINE - void exclusive_scan(T input, T& output, T init, T& reduction, - BinaryFunction scan_op) + ROCPRIM_DEVICE ROCPRIM_INLINE void exclusive_scan( + T input, T& output, storage_type& /*storage*/, T& reduction, BinaryFunction scan_op) + { + inclusive_scan(input, output, scan_op); + // Broadcast value from the last thread in warp + reduction = warp_shuffle(output, WarpSize - 1, WarpSize); + // Convert inclusive scan result to exclusive + to_exclusive(output, output); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE void + exclusive_scan(T input, T& output, T init, T& reduction, BinaryFunction scan_op) { inclusive_scan(input, output, scan_op); // Broadcast value from the last thread in warp @@ -234,8 +244,8 @@ class warp_scan_dpp } protected: - ROCPRIM_DEVICE ROCPRIM_INLINE - void to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) + [[deprecated]] ROCPRIM_DEVICE ROCPRIM_INLINE void + to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) { (void) storage; return to_exclusive(inclusive_input, exclusive_output); diff --git a/rocprim/include/rocprim/warp/detail/warp_scan_shared_mem.hpp b/rocprim/include/rocprim/warp/detail/warp_scan_shared_mem.hpp index bedc99b5b..47736b065 100644 --- a/rocprim/include/rocprim/warp/detail/warp_scan_shared_mem.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_scan_shared_mem.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -104,6 +104,15 @@ class warp_scan_shared_mem to_exclusive(output, storage); } + template + ROCPRIM_DEVICE ROCPRIM_INLINE void exclusive_scan( + T input, T& output, storage_type& storage, T& reduction, BinaryFunction scan_op) + { + inclusive_scan(input, output, storage, scan_op); + reduction = storage.get().threads[WarpSize - 1]; + to_exclusive(output, storage); + } + template ROCPRIM_DEVICE ROCPRIM_INLINE void exclusive_scan(T input, T& output, T init, T& reduction, @@ -158,8 +167,8 @@ class warp_scan_shared_mem } protected: - ROCPRIM_DEVICE ROCPRIM_INLINE - void to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) + [[deprecated]] ROCPRIM_DEVICE ROCPRIM_INLINE void + to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) { (void) inclusive_input; return to_exclusive(exclusive_output, storage); diff --git a/rocprim/include/rocprim/warp/detail/warp_scan_shuffle.hpp b/rocprim/include/rocprim/warp/detail/warp_scan_shuffle.hpp index 457297528..106888c4a 100644 --- a/rocprim/include/rocprim/warp/detail/warp_scan_shuffle.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_scan_shuffle.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -118,6 +118,17 @@ class warp_scan_shuffle to_exclusive(output, output); } + template + ROCPRIM_DEVICE ROCPRIM_INLINE void exclusive_scan( + T input, T& output, storage_type& /*storage*/, T& reduction, BinaryFunction scan_op) + { + inclusive_scan(input, output, scan_op); + // Broadcast value from the last thread in warp + reduction = warp_shuffle(output, WarpSize - 1, WarpSize); + // Convert inclusive scan result to exclusive + to_exclusive(output, output); + } + template ROCPRIM_DEVICE ROCPRIM_INLINE void exclusive_scan(T input, T& output, T init, T& reduction, @@ -198,8 +209,8 @@ class warp_scan_shuffle } protected: - ROCPRIM_DEVICE ROCPRIM_INLINE - void to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) + [[deprecated]] ROCPRIM_DEVICE ROCPRIM_INLINE void + to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) { (void) storage; return to_exclusive(inclusive_input, exclusive_output); diff --git a/rocprim/include/rocprim/warp/warp_exchange.hpp b/rocprim/include/rocprim/warp/warp_exchange.hpp index 99581091b..6e54f9d3b 100644 --- a/rocprim/include/rocprim/warp/warp_exchange.hpp +++ b/rocprim/include/rocprim/warp/warp_exchange.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -210,8 +210,8 @@ class warp_exchange { const auto value = ::rocprim::warp_shuffle( input[src_idx], - flat_id / ItemsPerThread + dst_idx * (WarpSize / ItemsPerThread) - ); + flat_id / ItemsPerThread + dst_idx * (WarpSize / ItemsPerThread), + WarpSize); if(src_idx == flat_id % ItemsPerThread) { work_array[dst_idx] = value; @@ -328,10 +328,10 @@ class warp_exchange ROCPRIM_UNROLL for(unsigned int src_idx = 0; src_idx < ItemsPerThread; src_idx++) { - const auto value = ::rocprim::warp_shuffle( - input[src_idx], - (ItemsPerThread * flat_id + dst_idx) % WarpSize - ); + const auto value + = ::rocprim::warp_shuffle(input[src_idx], + (ItemsPerThread * flat_id + dst_idx) % WarpSize, + WarpSize); if(flat_id / (WarpSize / ItemsPerThread) == src_idx) { work_array[dst_idx] = value; diff --git a/rocprim/include/rocprim/warp/warp_scan.hpp b/rocprim/include/rocprim/warp/warp_scan.hpp index 116d10360..9e8738046 100644 --- a/rocprim/include/rocprim/warp/warp_scan.hpp +++ b/rocprim/include/rocprim/warp/warp_scan.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -443,6 +443,80 @@ class warp_scan return; } + /// \brief Performs exclusive scan without an initial value across threads in a logical warp + /// \tparam BinaryFunction binary function used for scan + /// \param input Thread input value + /// \param[out] output Reference to thread output value. Each thread's value for the scan will + /// be written to it. May be aliased with `input`. The value written is unspecified for the first + /// thread of each logical warp. 
+ /// \param [in] storage Reference to a temporary storage object of type storage_type. + /// \param scan_op The function object used to combine elements used for the scan + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE auto exclusive_scan(T input, + T& output, + storage_type& storage, + BinaryFunction scan_op = BinaryFunction()) +#ifndef DOXYGEN_DOCUMENTATION_BUILD + -> std::enable_if_t +#else + -> void +#endif + { + base_type::exclusive_scan(input, output, storage, scan_op); + } + + /// \cond + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE auto exclusive_scan(T /*input*/, + T& /*output*/, + storage_type& /*storage*/, + BinaryFunction /*scan_op*/ = BinaryFunction()) + -> std::enable_if_t<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE)> + { + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size." + " Aborting warp scan."); + } + /// \endcond + + /// \brief Performs exclusive scan and reduction without an initial value across threads in + /// a logical warp + /// \tparam BinaryFunction binary function used for scan + /// \param input Thread input value + /// \param[out] output Reference to thread output value. Each thread's value for the scan will + /// be written to it. May be aliased with `input`. The value written is unspecified for the first + /// thread of each logical warp. + /// \param[out] reduction Result of reducing all `input` values in the logical warp. + /// \param [in] storage Reference to a temporary storage object of type storage_type. 
+ /// \param scan_op The function object used to combine elements used for the scan + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE auto exclusive_scan(T input, + T& output, + storage_type& storage, + T& reduction, + BinaryFunction scan_op = BinaryFunction()) +#ifndef DOXYGEN_DOCUMENTATION_BUILD + -> std::enable_if_t +#else + -> void +#endif + { + base_type::exclusive_scan(input, output, storage, reduction, scan_op); + } + + /// \cond + template, unsigned int FunctionWarpSize = WarpSize> + ROCPRIM_DEVICE ROCPRIM_INLINE auto exclusive_scan(T /*input*/, + T& /*output*/, + storage_type& /*storage*/, + T& /*reduction*/, + BinaryFunction /*scan_op*/ = BinaryFunction()) + -> std::enable_if_t<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE)> + { + ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size." + " Aborting warp scan."); + } + /// \endcond + /// \brief Performs inclusive and exclusive scan operations across threads /// in a logical warp. 
/// @@ -658,19 +732,18 @@ class warp_scan #ifndef DOXYGEN_SHOULD_SKIP_THIS protected: - + // These undocumented functions are used by hipCUB prior to version 3.1 template - ROCPRIM_DEVICE ROCPRIM_INLINE - auto to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) - -> typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type + [[deprecated]] ROCPRIM_DEVICE ROCPRIM_INLINE auto + to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) -> + typename std::enable_if<(FunctionWarpSize <= __AMDGCN_WAVEFRONT_SIZE), void>::type { return base_type::to_exclusive(inclusive_input, exclusive_output, storage); } template - ROCPRIM_DEVICE ROCPRIM_INLINE - auto to_exclusive(T , T& , storage_type&) - -> typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type + [[deprecated]] ROCPRIM_DEVICE ROCPRIM_INLINE auto to_exclusive(T, T&, storage_type&) -> + typename std::enable_if<(FunctionWarpSize > __AMDGCN_WAVEFRONT_SIZE), void>::type { ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size. 
Aborting warp sort."); return; diff --git a/scripts/autotune/create_optimization.py b/scripts/autotune/create_optimization.py index 677e02f3e..57c2e0d66 100755 --- a/scripts/autotune/create_optimization.py +++ b/scripts/autotune/create_optimization.py @@ -475,6 +475,16 @@ class AlgorithmDeviceAdjacentDifferenceInplace(Algorithm): def __init__(self, fallback_entries): Algorithm.__init__(self, fallback_entries) +class AlgorithmDeviceSegmentedRadixSort(Algorithm): + algorithm_name = 'device_segmented_radix_sort' + cpp_configuration_template_name = 'segmented_radix_sort_config_template' + config_selection_params = [ + SelectionType(name='key_type', is_optional=False), + SelectionType(name='value_type', is_optional=True)] + + def __init__(self, fallback_entries): + Algorithm.__init__(self, fallback_entries) + def filt_algo_regex(e, algorithm_name): if 'algo_regex' in e: return re.match(e['algo_regex'], algorithm_name) is not None @@ -508,6 +518,8 @@ def create_algorithm(algorithm_name: str, fallback_entries): return AlgorithmDeviceAdjacentDifference(fallback_entries) elif algorithm_name == 'device_adjacent_difference_inplace': return AlgorithmDeviceAdjacentDifferenceInplace(fallback_entries) + elif algorithm_name == 'device_segmented_radix_sort': + return AlgorithmDeviceSegmentedRadixSort(fallback_entries) else: raise(NotSupportedError(f'Algorithm "{algorithm_name}" is not supported (yet)')) diff --git a/scripts/autotune/templates/adjacent_difference_config_template b/scripts/autotune/templates/adjacent_difference_config_template index f40ad24cd..353e86d1f 100644 --- a/scripts/autotune/templates/adjacent_difference_config_template +++ b/scripts/autotune/templates/adjacent_difference_config_template @@ -10,7 +10,7 @@ adjacent_difference_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg'] {% macro general_case() -%} template -struct default_adjacent_difference_config : default_adjacent_difference_config_base +struct default_adjacent_difference_config : 
default_adjacent_difference_config_base::type {}; {%- endmacro %} diff --git a/scripts/autotune/templates/adjacent_difference_inplace_config_template b/scripts/autotune/templates/adjacent_difference_inplace_config_template index 1031bf5e0..3e5d7fa5c 100644 --- a/scripts/autotune/templates/adjacent_difference_inplace_config_template +++ b/scripts/autotune/templates/adjacent_difference_inplace_config_template @@ -10,7 +10,7 @@ adjacent_difference_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg'] {% macro general_case() -%} template -struct default_adjacent_difference_inplace_config : default_adjacent_difference_config_base +struct default_adjacent_difference_inplace_config : default_adjacent_difference_config_base::type {}; {%- endmacro %} diff --git a/scripts/autotune/templates/histogram_config_template b/scripts/autotune/templates/histogram_config_template index 60aff0382..d492ab5c8 100644 --- a/scripts/autotune/templates/histogram_config_template +++ b/scripts/autotune/templates/histogram_config_template @@ -11,7 +11,7 @@ histogram_config struct default_histogram_config : -default_histogram_config_base { }; +default_histogram_config_base::type { }; {%- endmacro %} {% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} diff --git a/scripts/autotune/templates/reduce_config_template b/scripts/autotune/templates/reduce_config_template index 5630d58dc..3d3924cf3 100644 --- a/scripts/autotune/templates/reduce_config_template +++ b/scripts/autotune/templates/reduce_config_template @@ -10,7 +10,7 @@ reduce_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, : {% macro general_case() -%} template struct default_reduce_config : -default_reduce_config_base { }; +default_reduce_config_base::type { }; {%- endmacro %} {% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} diff --git a/scripts/autotune/templates/scan_config_template 
b/scripts/autotune/templates/scan_config_template index 02f4fafa5..6de5a83ad 100644 --- a/scripts/autotune/templates/scan_config_template +++ b/scripts/autotune/templates/scan_config_template @@ -10,7 +10,7 @@ scan_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, ::r {% macro general_case() -%} template struct default_scan_config : -default_scan_config_base { }; +default_scan_config_base::type { }; {%- endmacro %} {% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} diff --git a/scripts/autotune/templates/scanbykey_config_template b/scripts/autotune/templates/scanbykey_config_template index e17a89de7..72653bb61 100644 --- a/scripts/autotune/templates/scanbykey_config_template +++ b/scripts/autotune/templates/scanbykey_config_template @@ -10,7 +10,7 @@ scan_by_key_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] {% macro general_case() -%} template struct default_scan_by_key_config : -default_scan_by_key_config_base { }; +default_scan_by_key_config_base::type { }; {%- endmacro %} {% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} diff --git a/scripts/autotune/templates/segmented_radix_sort_config_template b/scripts/autotune/templates/segmented_radix_sort_config_template new file mode 100644 index 000000000..0187df79e --- /dev/null +++ b/scripts/autotune/templates/segmented_radix_sort_config_template @@ -0,0 +1,20 @@ +{% extends "config_template" %} + +{% macro get_header_guard() %} +ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEGMENTED_RADIX_SORT_HPP_ +{%- endmacro %} + +{% macro kernel_configuration(measurement) -%} +segmented_radix_sort_config<{{ measurement['cfg']['lrb'] }}, {{ measurement['cfg']['srb'] }}, {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, {{ measurement['cfg']['eupws'] }}, typename std::conditional<{{ measurement['cfg']['wsc']['pa'] }},WarpSortConfig<{{ measurement['cfg']['wsc']['lwss'] 
}},{{ measurement['cfg']['wsc']['ipts'] }},{{ measurement['cfg']['wsc']['bss'] }},{{ measurement['cfg']['wsc']['pt'] }},{{ measurement['cfg']['wsc']['lwsm'] }},{{ measurement['cfg']['wsc']['iptm'] }},{{ measurement['cfg']['wsc']['bsm'] }}>,DisabledWarpSortConfig>::type> { }; +{%- endmacro %} + +{% macro general_case() -%} +template +struct default_segmented_radix_sort_config : default_segmented_radix_sort_config_base<6, 4>::type +{}; +{%- endmacro %} + +{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} +// Based on {{ based_on_type }} +template struct default_segmented_radix_sort_config({{ benchmark_of_architecture.name }}), key_type, value_type, {{ fallback_selection_criteria }}> : +{%- endmacro %} diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b7cc9b83d..aed06265f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -54,7 +54,9 @@ function(add_hip_test TEST_NAME TEST_SOURCES) get_filename_component(TEST_TARGET ${TEST_MAIN_SOURCE} NAME_WE) add_executable(${TEST_TARGET} ${TEST_SOURCES}) - rocm_install(TARGETS ${TEST_TARGET} COMPONENT tests) + if (ROCPRIM_INSTALL) + rocm_install(TARGETS ${TEST_TARGET} COMPONENT tests) + endif() target_include_directories(${TEST_TARGET} SYSTEM BEFORE PUBLIC @@ -130,9 +132,11 @@ add_subdirectory(rocprim) add_hip_test("hipgraph.basic" hipgraph/test_hipgraph_basic.cpp) add_hip_test("hipgraph.algs" hipgraph/test_hipgraph_algs.cpp) -rocm_install( - FILES "${INSTALL_TEST_FILE}" - DESTINATION "${CMAKE_INSTALL_BINDIR}/${PROJECT_NAME}" - COMPONENT tests - RENAME "CTestTestfile.cmake" -) +if (ROCPRIM_INSTALL) + rocm_install( + FILES "${INSTALL_TEST_FILE}" + DESTINATION "${CMAKE_INSTALL_BINDIR}/${PROJECT_NAME}" + COMPONENT tests + RENAME "CTestTestfile.cmake" + ) +endif() diff --git a/test/common_test_header.hpp b/test/common_test_header.hpp index 8ac8fdf50..62f6afcfd 100755 --- a/test/common_test_header.hpp +++ b/test/common_test_header.hpp @@ -60,6 +60,12 @@ } #endif +#if(defined(__GNUC__) || defined(__clang__)) && (defined(__GLIBCXX__) || defined(_LIBCPP_VERSION)) + #define ROCPRIM_HAS_INT128_SUPPORT 1 +#else + #define ROCPRIM_HAS_INT128_SUPPORT 0 +#endif + #define INSTANTIATE_TYPED_TEST_EXPANDED_1(line, test_suite_name, ...) \ namespace Id##line \ { \ @@ -110,19 +116,27 @@ inline int obtain_device_from_ctest() #endif } +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning( \ + disable : 4996) // This function or variable may be unsafe. Consider using _dupenv_s instead. 
+#endif inline bool use_hmm() { - if (getenv("ROCPRIM_USE_HMM") == nullptr) + if(getenv("ROCPRIM_USE_HMM") == nullptr) { return false; } - if (strcmp(getenv("ROCPRIM_USE_HMM"), "1") == 0) + if(strcmp(getenv("ROCPRIM_USE_HMM"), "1") == 0) { return true; } return false; } +#ifdef _MSC_VER + #pragma warning(pop) +#endif // Helper for HMM allocations: HMM is requested through ROCPRIM_USE_HMM=1 environment variable template diff --git a/test/hipgraph/test_hipgraph_algs.cpp b/test/hipgraph/test_hipgraph_algs.cpp index a0811fd89..816cff1f4 100644 --- a/test/hipgraph/test_hipgraph_algs.cpp +++ b/test/hipgraph/test_hipgraph_algs.cpp @@ -52,7 +52,9 @@ void generate_needles(const std::vector& input, std::vector& o output[indices.size() + i] = out_of_bounds_vals[i]; // Mix up the in-bounds and out-of-bounds values to make the test a bit more robust - std::random_shuffle(output.begin(), output.end()); + std::random_device rd; + std::default_random_engine gen(rd()); + std::shuffle(output.begin(), output.end(), gen); } template diff --git a/test/rocprim/CMakeLists.txt b/test/rocprim/CMakeLists.txt index 3fbc675ce..6ab37e809 100644 --- a/test/rocprim/CMakeLists.txt +++ b/test/rocprim/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -36,7 +36,9 @@ endfunction() function(add_rocprim_test_internal TEST_NAME TEST_SOURCES TEST_TARGET) add_executable(${TEST_TARGET} ${TEST_SOURCES}) - rocm_install(TARGETS ${TEST_TARGET} COMPONENT tests) + if (ROCPRIM_INSTALL) + rocm_install(TARGETS ${TEST_TARGET} COMPONENT tests) + endif() target_include_directories(${TEST_TARGET} SYSTEM BEFORE PUBLIC @@ -261,13 +263,14 @@ add_rocprim_test("rocprim.device_transform" test_device_transform.cpp) add_rocprim_test("rocprim.discard_iterator" test_discard_iterator.cpp) add_rocprim_test("rocprim.reverse_iterator" test_reverse_iterator.cpp) if(NOT USE_HIP_CPU) - add_rocprim_test("rocprim.texture_cache_iterator" test_texture_cache_iterator.cpp) +add_rocprim_test("rocprim.texture_cache_iterator" test_texture_cache_iterator.cpp) endif() add_rocprim_test("rocprim.thread" test_thread.cpp) add_rocprim_test("rocprim.thread_algos" test_thread_algos.cpp) add_rocprim_test("rocprim.transform_iterator" test_transform_iterator.cpp) add_rocprim_test("rocprim.no_half_operators" test_no_half_operators.cpp) add_rocprim_test("rocprim.intrinsics" test_intrinsics.cpp) +add_rocprim_test("rocprim.invoke_result" test_invoke_result.cpp) add_rocprim_test("rocprim.warp_exchange" test_warp_exchange.cpp) add_rocprim_test("rocprim.warp_load" test_warp_load.cpp) add_rocprim_test("rocprim.warp_reduce" test_warp_reduce.cpp) diff --git a/test/rocprim/indirect_iterator.hpp b/test/rocprim/indirect_iterator.hpp new file mode 100644 index 000000000..68328f7bb --- /dev/null +++ b/test/rocprim/indirect_iterator.hpp @@ -0,0 +1,179 @@ +// Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef TEST_INDIRECT_ITERATOR_HPP_ +#define TEST_INDIRECT_ITERATOR_HPP_ + +#include +#include + +#include + +namespace test_utils +{ + +// assign-through reference_wrapper implementation +template +class reference_wrapper +{ +public: + // types + using type = T; + + // construct/copy/destroy + explicit constexpr reference_wrapper(T& t) : _ptr(&t) {} + + constexpr reference_wrapper(const reference_wrapper&) noexcept = default; + + // assignment + constexpr reference_wrapper& operator=(const T& x) noexcept + { + *_ptr = x; + return *this; + } + + // access + constexpr operator T&() const noexcept + { + return *_ptr; + } + constexpr T& get() const noexcept + { + return *_ptr; + } + +private: + T* _ptr; +}; + +// Iterator used in tests to check situations when value_type of the +// iterator is not the same as the return type of operator[]. 
+// It is a simplified version of device_vector::iterator from thrust. +template +class indirect_iterator +{ +public: + // Iterator traits + using difference_type = std::ptrdiff_t; + using value_type = T; + using pointer = T*; + using reference = reference_wrapper; + + using iterator_category = std::random_access_iterator_tag; + + ROCPRIM_HOST_DEVICE inline indirect_iterator(T* ptr) : ptr_(ptr) {} + + ROCPRIM_HOST_DEVICE inline ~indirect_iterator() = default; + + ROCPRIM_HOST_DEVICE inline indirect_iterator& operator++() + { + ++ptr_; + return *this; + } + + ROCPRIM_HOST_DEVICE inline indirect_iterator operator++(int) + { + indirect_iterator old = *this; + ++ptr_; + return old; + } + + ROCPRIM_HOST_DEVICE inline indirect_iterator& operator--() + { + --ptr_; + return *this; + } + + ROCPRIM_HOST_DEVICE inline indirect_iterator operator--(int) + { + indirect_iterator old = *this; + --ptr_; + return old; + } + + ROCPRIM_HOST_DEVICE inline reference operator*() const + { + return *ptr_; + } + + ROCPRIM_HOST_DEVICE inline reference operator[](difference_type n) const + { + return reference{*(ptr_ + n)}; + } + + ROCPRIM_HOST_DEVICE inline indirect_iterator operator+(difference_type distance) const + { + auto i = ptr_ + distance; + return indirect_iterator{i}; + } + + ROCPRIM_HOST_DEVICE inline indirect_iterator& operator+=(difference_type distance) + { + ptr_ += distance; + return *this; + } + + ROCPRIM_HOST_DEVICE inline indirect_iterator operator-(difference_type distance) const + { + auto i = ptr_ - distance; + return indirect_iterator{i}; + } + + ROCPRIM_HOST_DEVICE inline indirect_iterator& operator-=(difference_type distance) + { + ptr_ -= distance; + return *this; + } + + ROCPRIM_HOST_DEVICE inline difference_type operator-(indirect_iterator other) const + { + return ptr_ - other.ptr_; + } + + ROCPRIM_HOST_DEVICE inline bool operator==(indirect_iterator other) const + { + return ptr_ == other.ptr_; + } + + ROCPRIM_HOST_DEVICE inline bool 
operator!=(indirect_iterator other) const + { + return ptr_ != other.ptr_; + } + +private: + T* ptr_; +}; + +template +inline auto wrap_in_indirect_iterator(T* ptr) -> + typename std::enable_if>::type +{ + return indirect_iterator(ptr); +} + +template +inline auto wrap_in_indirect_iterator(T* ptr) -> typename std::enable_if::type +{ + return ptr; +} + +} // namespace test_utils + +#endif // TEST_INDIRECT_ITERATOR_HPP_ diff --git a/test/rocprim/test_device_adjacent_difference.cpp b/test/rocprim/test_device_adjacent_difference.cpp index 197a170f7..fa97eecb6 100644 --- a/test/rocprim/test_device_adjacent_difference.cpp +++ b/test/rocprim/test_device_adjacent_difference.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -22,6 +22,7 @@ #include "../common_test_header.hpp" +#include "indirect_iterator.hpp" #include "test_utils_types.hpp" #include @@ -37,70 +38,131 @@ namespace { -template -auto dispatch_adjacent_difference(std::true_type /*left*/, - std::false_type /*in_place*/, - void* const temporary_storage, - std::size_t& storage_size, - const InputIt input, - const OutputIt output, - Args&&... args) +enum class api_variant +{ + no_alias, + alias, + in_place +}; + +std::string to_string(api_variant aliasing) +{ + switch(aliasing) + { + case api_variant::no_alias: return "no_alias"; + case api_variant::alias: return "alias"; + case api_variant::in_place: return "in_place"; + } +} + +template +auto dispatch_adjacent_difference( + std::true_type /*left*/, + std::integral_constant /*aliasing*/, + void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt output, + Args&&... 
args) { return ::rocprim::adjacent_difference( temporary_storage, storage_size, input, output, std::forward(args)...); } -template -auto dispatch_adjacent_difference(std::false_type /*left*/, - std::false_type /*in_place*/, - void* const temporary_storage, - std::size_t& storage_size, - const InputIt input, - const OutputIt output, - Args&&... args) +template +auto dispatch_adjacent_difference( + std::false_type /*left*/, + std::integral_constant /*aliasing*/, + void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt output, + Args&&... args) { return ::rocprim::adjacent_difference_right( temporary_storage, storage_size, input, output, std::forward(args)...); } -template -auto dispatch_adjacent_difference(std::true_type /*left*/, - std::true_type /*in_place*/, - void* const temporary_storage, - std::size_t& storage_size, - const InputIt input, - const OutputIt /*output*/, - Args&&... args) +template +auto dispatch_adjacent_difference( + std::true_type /*left*/, + std::integral_constant /*aliasing*/, + void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt /*output*/, + Args&&... args) { return ::rocprim::adjacent_difference_inplace( temporary_storage, storage_size, input, std::forward(args)...); } -template -auto dispatch_adjacent_difference(std::false_type /*left*/, - std::true_type /*in_place*/, - void* const temporary_storage, - std::size_t& storage_size, - const InputIt input, - const OutputIt /*output*/, - Args&&... args) +template +auto dispatch_adjacent_difference( + std::false_type /*left*/, + std::integral_constant /*aliasing*/, + void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt /*output*/, + Args&&... 
args) { return ::rocprim::adjacent_difference_right_inplace( temporary_storage, storage_size, input, std::forward(args)...); } +template +auto dispatch_adjacent_difference( + std::true_type /*left*/, + std::integral_constant /*aliasing*/, + void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt output, + Args&&... args) +{ + return ::rocprim::adjacent_difference_inplace(temporary_storage, + storage_size, + input, + output, + std::forward(args)...); +} + +template +auto dispatch_adjacent_difference( + std::false_type /*left*/, + std::integral_constant /*aliasing*/, + void* const temporary_storage, + std::size_t& storage_size, + const InputIt input, + const OutputIt output, + Args&&... args) +{ + return ::rocprim::adjacent_difference_right_inplace(temporary_storage, + storage_size, + input, + output, + std::forward(args)...); +} + template auto get_expected_result(const std::vector& input, const BinaryFunction op, @@ -125,35 +187,38 @@ auto get_expected_result(const std::vector& input, // Params for tests template + class OutputType = InputType, + bool Left = true, + api_variant Aliasing = api_variant::no_alias, + bool UseIdentityIterator = false, + class Config = rocprim::default_config, + bool UseGraphs = false, + bool UseIndirectIterator = false> struct DeviceAdjacentDifferenceParams { - using input_type = InputType; - using output_type = OutputType; - static constexpr bool left = Left; - static constexpr bool in_place = InPlace; - static constexpr bool use_identity_iterator = UseIdentityIterator; - using config = Config; - static constexpr bool use_graphs = UseGraphs; + using input_type = InputType; + using output_type = OutputType; + static constexpr bool left = Left; + static constexpr api_variant aliasing = Aliasing; + static constexpr bool use_identity_iterator = UseIdentityIterator; + using config = Config; + static constexpr bool use_graphs = UseGraphs; + static constexpr bool use_indirect_iterator = 
UseIndirectIterator; }; template class RocprimDeviceAdjacentDifferenceTests : public ::testing::Test { public: - using input_type = typename Params::input_type; - using output_type = typename Params::output_type; - static constexpr bool left = Params::left; - static constexpr bool in_place = Params::in_place; - static constexpr bool use_identity_iterator = Params::use_identity_iterator; - static constexpr bool debug_synchronous = false; - using config = typename Params::config; - static constexpr bool use_graphs = Params::use_graphs; + using input_type = typename Params::input_type; + using output_type = typename Params::output_type; + static constexpr bool left = Params::left; + static constexpr api_variant aliasing = Params::aliasing; + static constexpr bool use_identity_iterator = Params::use_identity_iterator; + static constexpr bool use_indirect_iterator = Params::use_indirect_iterator; + static constexpr bool debug_synchronous = false; + using config = typename Params::config; + static constexpr bool use_graphs = Params::use_graphs; }; using custom_double2 = test_utils::custom_test_type; @@ -173,19 +238,60 @@ using RocprimDeviceAdjacentDifferenceTestsParams = ::testing::Types< // Tests with default configuration DeviceAdjacentDifferenceParams, DeviceAdjacentDifferenceParams, - DeviceAdjacentDifferenceParams, - DeviceAdjacentDifferenceParams, - DeviceAdjacentDifferenceParams, - DeviceAdjacentDifferenceParams, - DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, + // this is changed to not use identity iterator + // because the function doesn't work with it, should be changed back, when fixed + DeviceAdjacentDifferenceParams, // Tests for supported config structs - DeviceAdjacentDifferenceParams, - DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, + DeviceAdjacentDifferenceParams, // Tests for 
different size_limits - DeviceAdjacentDifferenceParams>, - DeviceAdjacentDifferenceParams>, - DeviceAdjacentDifferenceParams>, - DeviceAdjacentDifferenceParams>; + DeviceAdjacentDifferenceParams>, + DeviceAdjacentDifferenceParams, + false, + true>, + DeviceAdjacentDifferenceParams, + false, + true>, + DeviceAdjacentDifferenceParams>; TYPED_TEST_SUITE(RocprimDeviceAdjacentDifferenceTests, RocprimDeviceAdjacentDifferenceTestsParams); @@ -195,15 +301,16 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceTests, AdjacentDifference) SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); - using T = typename TestFixture::input_type; - using output_type = typename TestFixture::output_type; - static constexpr bool left = TestFixture::left; - static constexpr bool in_place = TestFixture::in_place; - static constexpr bool use_identity_iterator = TestFixture::use_identity_iterator; - static constexpr bool debug_synchronous = TestFixture::debug_synchronous; - using Config = typename TestFixture::config; + using T = typename TestFixture::input_type; + using output_type = typename TestFixture::output_type; + static constexpr bool left = TestFixture::left; + static constexpr api_variant aliasing = TestFixture::aliasing; + const bool debug_synchronous = TestFixture::debug_synchronous; + static constexpr bool use_identity_iterator = TestFixture::use_identity_iterator; + using Config = typename TestFixture::config; - SCOPED_TRACE(testing::Message() << "left = " << left << ", in_place = " << in_place); + SCOPED_TRACE(testing::Message() + << "left = " << left << ", api_variant = " << to_string(aliasing)); for(std::size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -223,123 +330,149 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceTests, AdjacentDifference) SCOPED_TRACE(testing::Message() << "with size = " << size); // Generate data - const std::vector input = test_utils::get_random_data(size, 1, 100, 
seed_value); - std::vector output(input.size()); + std::vector input = test_utils::get_random_data(size, 1, 100, seed_value); - T* d_input; - output_type* d_output = nullptr; + T* d_input; HIP_CHECK( test_common_utils::hipMallocHelper(&d_input, input.size() * sizeof(input[0]))); HIP_CHECK(hipMemcpy( d_input, input.data(), input.size() * sizeof(input[0]), hipMemcpyHostToDevice)); - if(!in_place) - { - HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, - output.size() * sizeof(output[0]))); - } + static constexpr auto left_tag = rocprim::detail::bool_constant{}; + static constexpr auto alias_tag = std::integral_constant{}; - static constexpr auto left_tag = rocprim::detail::bool_constant {}; - static constexpr auto in_place_tag = rocprim::detail::bool_constant {}; + auto input_it + = test_utils::wrap_in_indirect_iterator( + d_input); - // Calculate expected results on host - const auto expected - = get_expected_result(input, rocprim::minus<> {}, left_tag); - - const auto output_it - = test_utils::wrap_in_identity_iterator(d_output); - - hipGraph_t graph; + hipGraph_t graph; hipGraphExec_t graph_instance; - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) graph = test_utils::createGraphHelper(stream); - + // Allocate temporary storage std::size_t temp_storage_size; void* d_temp_storage = nullptr; HIP_CHECK(dispatch_adjacent_difference(left_tag, - in_place_tag, + alias_tag, d_temp_storage, temp_storage_size, - d_input, - output_it, + input_it, + (output_type*){nullptr}, size, - rocprim::minus<> {}, + rocprim::minus<>{}, stream, - TestFixture::debug_synchronous)); + debug_synchronous)); - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); - + ASSERT_GT(temp_storage_size, 0); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size)); - if (TestFixture::use_graphs) - test_utils::resetGraphHelper(graph, graph_instance, stream); - - // Run - 
HIP_CHECK(dispatch_adjacent_difference(left_tag, - in_place_tag, - d_temp_storage, - temp_storage_size, - d_input, - output_it, - size, - rocprim::minus<> {}, - stream, - TestFixture::debug_synchronous)); - HIP_CHECK(hipGetLastError()); + // We might call the API multiple times, with almost the same parameter + // (in-place and out-of-place) + // we should be able to use the same amount of temp storage for and get the same + // results (maybe with different types) for both. + auto run_and_verify = [&](const auto output_it, auto* d_output) + { + // Run + HIP_CHECK(dispatch_adjacent_difference(left_tag, + alias_tag, + d_temp_storage, + temp_storage_size, + input_it, + output_it, + size, + rocprim::minus<>{}, + stream, + debug_synchronous)); + HIP_CHECK(hipGetLastError()); + + if(TestFixture::use_graphs) + { + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + } + + // input_type for in-place, output_type for out of place + using current_output_type = std::remove_reference_t; + + // allocate memory for output + std::vector output(size); + + // Copy output to host + HIP_CHECK(hipMemcpy(output.data(), + d_output, + output.size() * sizeof(output[0]), + hipMemcpyDeviceToHost)); + + // Calculate expected results on host + const auto expected + = get_expected_result(input, rocprim::minus<>{}, left_tag); + + // Check if output values are as expected + test_utils::assert_near( + output, + expected, + std::max(test_utils::precision, test_utils::precision)); + }; + + // if api_variant is not in_place we should check the non aliased function call + if(aliasing != api_variant::in_place) + { + output_type* d_output = nullptr; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, size * sizeof(*d_output))); - if (TestFixture::use_graphs) - graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + if(TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); - // Copy output to host - HIP_CHECK( - 
hipMemcpy(output.data(), - in_place ? static_cast(d_input) : static_cast(d_output), - output.size() * sizeof(output[0]), - hipMemcpyDeviceToHost)); + const auto output_it + = test_utils::wrap_in_identity_iterator(d_output); - // Check if output values are as expected - ASSERT_NO_FATAL_FAILURE(test_utils::assert_near( - output, - expected, - std::max(test_utils::precision, test_utils::precision))); + ASSERT_NO_FATAL_FAILURE(run_and_verify(output_it, d_output)); - hipFree(d_input); - if(!in_place) - { hipFree(d_output); } - hipFree(d_temp_storage); - if (TestFixture::use_graphs) + // if api_variant is not no_alias we should check the inplace function call + if(aliasing != api_variant::no_alias) + { + if(TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + + ASSERT_NO_FATAL_FAILURE(run_and_verify(input_it, d_input)); + } + + if(TestFixture::use_graphs) { test_utils::cleanupGraphHelper(graph, graph_instance); HIP_CHECK(hipStreamDestroy(stream)); } + + hipFree(d_temp_storage); + hipFree(d_input); } } } // Params for tests -template +template struct DeviceAdjacentDifferenceLargeParams { - static constexpr bool left = Left; - static constexpr bool in_place = InPlace; - static constexpr bool use_graphs = UseGraphs; + static constexpr bool left = Left; + static constexpr api_variant aliasing = Aliasing; + static constexpr bool use_graphs = UseGraphs; }; template class RocprimDeviceAdjacentDifferenceLargeTests : public ::testing::Test { public: - static constexpr bool left = Params::left; - static constexpr bool in_place = Params::in_place; - static constexpr bool debug_synchronous = false; - static constexpr bool use_graphs = Params::use_graphs; + static constexpr bool left = Params::left; + static constexpr api_variant aliasing = Params::aliasing; + static constexpr bool debug_synchronous = false; + static constexpr bool use_graphs = Params::use_graphs; }; template @@ -451,9 +584,10 @@ class check_output_iterator }; using 
RocprimDeviceAdjacentDifferenceLargeTestsParams - = ::testing::Types, - DeviceAdjacentDifferenceLargeParams, - DeviceAdjacentDifferenceLargeParams>; + = ::testing::Types, + DeviceAdjacentDifferenceLargeParams, + DeviceAdjacentDifferenceLargeParams, + DeviceAdjacentDifferenceLargeParams>; TYPED_TEST_SUITE(RocprimDeviceAdjacentDifferenceLargeTests, RocprimDeviceAdjacentDifferenceLargeTestsParams); @@ -467,13 +601,14 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) using T = size_t; static constexpr bool is_left = TestFixture::left; - static constexpr bool is_in_place = TestFixture::in_place; + static constexpr api_variant aliasing = TestFixture::aliasing; + const bool debug_synchronous = TestFixture::debug_synchronous; static constexpr unsigned int sampling_rate = 10000; using OutputIterator = check_output_iterator; using flag_type = OutputIterator::flag_type; SCOPED_TRACE(testing::Message() - << "is_left = " << is_left << ", is_in_place = " << is_in_place); + << "is_left = " << is_left << ", api_variant = " << to_string(aliasing)); hipStream_t stream = 0; // default if (TestFixture::use_graphs) @@ -514,7 +649,7 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) { return (smaller_value + larger_value) / 2 + (is_left ? 
1 : 0); }; static constexpr auto left_tag = rocprim::detail::bool_constant{}; - static constexpr auto in_place_tag = rocprim::detail::bool_constant{}; + static constexpr auto aliasing_tag = std::integral_constant{}; hipGraph_t graph; hipGraphExec_t graph_instance; @@ -525,7 +660,7 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) std::size_t temp_storage_size; void* d_temp_storage = nullptr; HIP_CHECK(dispatch_adjacent_difference(left_tag, - in_place_tag, + aliasing_tag, d_temp_storage, temp_storage_size, input, @@ -533,7 +668,7 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) size, op, stream, - TestFixture::debug_synchronous)); + debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); @@ -547,7 +682,7 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) // Run HIP_CHECK(dispatch_adjacent_difference(left_tag, - in_place_tag, + aliasing_tag, d_temp_storage, temp_storage_size, input, @@ -555,7 +690,7 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) size, op, stream, - TestFixture::debug_synchronous)); + debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); diff --git a/test/rocprim/test_device_radix_sort.cpp.in b/test/rocprim/test_device_radix_sort.cpp.in index 4d0931246..65d34f6a6 100644 --- a/test/rocprim/test_device_radix_sort.cpp.in +++ b/test/rocprim/test_device_radix_sort.cpp.in @@ -50,7 +50,7 @@ #endif #if ROCPRIM_TEST_TYPE_SLICE == 0 -#if defined(__GNUC__) || defined(__clang__) +#if ROCPRIM_HAS_INT128_SUPPORT INSTANTIATE(params<__int128_t, __int128_t>) INSTANTIATE(params<__uint128_t, __uint128_t>) #endif diff --git a/test/rocprim/test_device_scan.cpp b/test/rocprim/test_device_scan.cpp index 8511ec6d6..9f2bdac22 100644 --- a/test/rocprim/test_device_scan.cpp +++ b/test/rocprim/test_device_scan.cpp @@ -1,6 +1,6 @@ // MIT License // 
-// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -150,12 +150,12 @@ TYPED_TEST_SUITE(RocprimDeviceScanTests, RocprimDeviceScanTestsParams); TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) { - using T = typename TestFixture::input_type; - using U = typename TestFixture::output_type; + using T = typename TestFixture::input_type; + using U = typename TestFixture::output_type; using scan_op_type = typename TestFixture::scan_op_type; // if scan_op_type is rocprim::plus and input_type is bfloat16 or half, // use float as device-side accumulator and double as host-side accumulator - using acc_type = typename accum_type::type; + using acc_type = typename accum_type::type; const bool debug_synchronous = TestFixture::debug_synchronous; int device_id = test_common_utils::obtain_device_from_ctest(); @@ -182,9 +182,9 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) // scan function scan_op_type scan_op; - auto input_iterator = rocprim::make_transform_iterator( - rocprim::make_constant_iterator(T(345)), - [] (T in) { return static_cast(in); }); + auto input_iterator + = rocprim::make_transform_iterator(rocprim::make_constant_iterator(T(345)), + [](T in) { return static_cast(in); }); hipGraph_t graph; hipGraphExec_t graph_instance; @@ -317,8 +317,9 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) expected.begin(), scan_op ); - auto input_iterator = rocprim::make_transform_iterator( - d_input, [] (T in) { return static_cast(in); }); + auto input_iterator + = rocprim::make_transform_iterator(d_input, + [](T in) { return static_cast(in); }); hipGraph_t graph; hipGraphExec_t graph_instance; @@ -329,13 +330,15 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) size_t 
temp_storage_size_bytes; void * d_temp_storage = nullptr; // Get size of d_temp_storage - HIP_CHECK( - rocprim::inclusive_scan( - d_temp_storage, temp_storage_size_bytes, input_iterator, - test_utils::wrap_in_identity_iterator(d_output), - input.size(), scan_op, stream, TestFixture::debug_synchronous - ) - ); + HIP_CHECK(rocprim::inclusive_scan( + d_temp_storage, + temp_storage_size_bytes, + input_iterator, + test_utils::wrap_in_identity_iterator(d_output), + input.size(), + scan_op, + stream, + TestFixture::debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); @@ -351,13 +354,15 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) test_utils::resetGraphHelper(graph, graph_instance, stream); // Run - HIP_CHECK( - rocprim::inclusive_scan( - d_temp_storage, temp_storage_size_bytes, input_iterator, - test_utils::wrap_in_identity_iterator(d_output), - input.size(), scan_op, stream, TestFixture::debug_synchronous - ) - ); + HIP_CHECK(rocprim::inclusive_scan( + d_temp_storage, + temp_storage_size_bytes, + input_iterator, + test_utils::wrap_in_identity_iterator(d_output), + input.size(), + scan_op, + stream, + TestFixture::debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); @@ -474,8 +479,9 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) scan_op ); - auto input_iterator = rocprim::make_transform_iterator( - d_input, [] (T in) { return static_cast(in); }); + auto input_iterator + = rocprim::make_transform_iterator(d_input, + [](T in) { return static_cast(in); }); hipGraph_t graph; hipGraphExec_t graph_instance; @@ -486,13 +492,16 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; // Get size of d_temp_storage - HIP_CHECK( - rocprim::exclusive_scan( - d_temp_storage, temp_storage_size_bytes, input_iterator, - 
test_utils::wrap_in_identity_iterator(d_output), - initial_value, input.size(), scan_op, stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::exclusive_scan( + d_temp_storage, + temp_storage_size_bytes, + input_iterator, + test_utils::wrap_in_identity_iterator(d_output), + initial_value, + input.size(), + scan_op, + stream, + debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); @@ -508,13 +517,16 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) test_utils::resetGraphHelper(graph, graph_instance, stream); // Run - HIP_CHECK( - rocprim::exclusive_scan( - d_temp_storage, temp_storage_size_bytes, input_iterator, - test_utils::wrap_in_identity_iterator(d_output), - initial_value, input.size(), scan_op, stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::exclusive_scan( + d_temp_storage, + temp_storage_size_bytes, + input_iterator, + test_utils::wrap_in_identity_iterator(d_output), + initial_value, + input.size(), + scan_op, + stream, + debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); @@ -646,8 +658,9 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) scan_op, keys_compare_op ); - auto input_iterator = rocprim::make_transform_iterator( - d_input, [] (T in) { return static_cast(in); }); + auto input_iterator + = rocprim::make_transform_iterator(d_input, + [](T in) { return static_cast(in); }); hipGraph_t graph; hipGraphExec_t graph_instance; @@ -658,12 +671,16 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; // Get size of d_temp_storage - HIP_CHECK( - rocprim::inclusive_scan_by_key( - d_temp_storage, temp_storage_size_bytes, d_keys, input_iterator, - d_output, input.size(), scan_op, keys_compare_op, stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::inclusive_scan_by_key(d_temp_storage, + temp_storage_size_bytes, + 
d_keys, + input_iterator, + d_output, + input.size(), + scan_op, + keys_compare_op, + stream, + debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); @@ -679,12 +696,16 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) test_utils::resetGraphHelper(graph, graph_instance, stream); // Run - HIP_CHECK( - rocprim::inclusive_scan_by_key( - d_temp_storage, temp_storage_size_bytes, d_keys, input_iterator, - d_output, input.size(), scan_op, keys_compare_op, stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::inclusive_scan_by_key(d_temp_storage, + temp_storage_size_bytes, + d_keys, + input_iterator, + d_output, + input.size(), + scan_op, + keys_compare_op, + stream, + debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); @@ -820,8 +841,9 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) scan_op, keys_compare_op ); - auto input_iterator = rocprim::make_transform_iterator( - d_input, [] (T in) { return static_cast(in); }); + auto input_iterator + = rocprim::make_transform_iterator(d_input, + [](T in) { return static_cast(in); }); hipGraph_t graph; hipGraphExec_t graph_instance; @@ -832,12 +854,17 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; // Get size of d_temp_storage - HIP_CHECK( - rocprim::exclusive_scan_by_key( - d_temp_storage, temp_storage_size_bytes, d_keys, input_iterator, - d_output, initial_value, input.size(), scan_op, keys_compare_op, stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::exclusive_scan_by_key(d_temp_storage, + temp_storage_size_bytes, + d_keys, + input_iterator, + d_output, + initial_value, + input.size(), + scan_op, + keys_compare_op, + stream, + debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); @@ -853,12 +880,17 @@ 
TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) test_utils::resetGraphHelper(graph, graph_instance, stream); // Run - HIP_CHECK( - rocprim::exclusive_scan_by_key( - d_temp_storage, temp_storage_size_bytes, d_keys, input_iterator, - d_output, initial_value, input.size(), scan_op, keys_compare_op, stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::exclusive_scan_by_key(d_temp_storage, + temp_storage_size_bytes, + d_keys, + input_iterator, + d_output, + initial_value, + input.size(), + scan_op, + keys_compare_op, + stream, + debug_synchronous)); if (TestFixture::use_graphs) graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); @@ -1630,8 +1662,9 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) = rocprim::future_value>{ future_iter}; - auto input_iterator = rocprim::make_transform_iterator( - d_input, [] (T in) { return static_cast(in); }); + auto input_iterator + = rocprim::make_transform_iterator(d_input, + [](T in) { return static_cast(in); }); hipGraph_t graph; hipGraphExec_t graph_instance; @@ -1643,7 +1676,9 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) char* d_temp_storage = nullptr; // Get size of d_temp_storage HIP_CHECK(rocprim::exclusive_scan( - nullptr, temp_storage_size_bytes, input_iterator, + nullptr, + temp_storage_size_bytes, + input_iterator, test_utils::wrap_in_identity_iterator(d_output), future_initial_value, input.size(), @@ -1686,7 +1721,9 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) // Run HIP_CHECK(rocprim::exclusive_scan( - d_temp_storage, temp_storage_size_bytes, input_iterator, + d_temp_storage, + temp_storage_size_bytes, + input_iterator, test_utils::wrap_in_identity_iterator(d_output), future_initial_value, input.size(), diff --git a/test/rocprim/test_device_segmented_radix_sort.hpp b/test/rocprim/test_device_segmented_radix_sort.hpp index d75f068ed..34e462a62 100644 --- a/test/rocprim/test_device_segmented_radix_sort.hpp +++ 
b/test/rocprim/test_device_segmented_radix_sort.hpp @@ -53,44 +53,47 @@ struct params using config = Config; }; -using config_default = rocprim::segmented_radix_sort_config< - 4, //< long radix bits - 3, //< short radix bits - rocprim::kernel_config<256, 4> //< sort block size, items per thread - >; - -using config_semi_custom = rocprim::segmented_radix_sort_config< - 3, //< long radix bits - 2, //< short radix bits - rocprim::kernel_config<128, 4>, //< sort block size, items per thread - rocprim::WarpSortConfig<16, //< logical warp size small - 8 //< items per thread small - >>; - -using config_semi_custom_warp_config = rocprim::segmented_radix_sort_config< - 3, //< long radix bits - 2, //< short radix bits - rocprim::kernel_config<128, 4>, //< sort block size, items per thread - rocprim::WarpSortConfig<16, //< logical warp size small - 2, //< items per thread small - 512, //< block size small - 0, //< partitioning threshold - true //< enable unpartitioned sort - >>; - -using config_custom = rocprim::segmented_radix_sort_config< - 3, //< long radix bits - 2, //< short radix bits - rocprim::kernel_config<128, 4>, //< sort block size, items per thread - rocprim::WarpSortConfig<16, //< logical warp size small - 2, //< items per thread small - 512, //< block size small - 0, //< partitioning threshold - true, //< enable unpartitioned sort - 32, //< logical warp size medium - 4, //< items per thread medium - 256 //< block size medium - >>; +using config_default = rocprim::segmented_radix_sort_config<4, //< long radix bits + 3, //< short radix bits + 256, //< sort block size, + 4 //< items per thread + >; + +using config_semi_custom + = rocprim::segmented_radix_sort_config<3, //< long radix bits + 2, //< short radix bits + 128, //< sort block size + 4, //< items per thread + false, //< enable unpartitioned sort + rocprim::WarpSortConfig<16, //< logical warp size small + 8 //< items per thread small + >>; + +using config_semi_custom_warp_config + = 
rocprim::segmented_radix_sort_config<3, //< long radix bits + 2, //< short radix bits + 128, //< sort block size + 4, //< items per thread + true, //< enable unpartitioned sort + rocprim::WarpSortConfig<16, //< logical warp size small + 2, //< items per thread small + 512, //< block size small + 0>>; //< partitioning threshold + +using config_custom + = rocprim::segmented_radix_sort_config<3, //< long radix bits + 2, //< short radix bits + 128, //< sort block size + 4, //< items per thread + true, //< enable unpartitioned sort + rocprim::WarpSortConfig<16, //< logical warp size small + 2, //< items per thread small + 512, //< block size small + 0, //< partitioning threshold + 32, //< logical warp size medium + 4, //< items per thread medium + 256 //< block size medium + >>; template class RocprimDeviceSegmentedRadixSort : public ::testing::Test diff --git a/test/rocprim/test_device_select.cpp b/test/rocprim/test_device_select.cpp index b903d5b4d..88d5562b4 100644 --- a/test/rocprim/test_device_select.cpp +++ b/test/rocprim/test_device_select.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -977,7 +977,7 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) if (TestFixture::use_graphs) test_utils::resetGraphHelper(graph, graph_instance, stream); - + // Run HIP_CHECK( rocprim::unique_by_key( @@ -997,7 +997,7 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) if (TestFixture::use_graphs) graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); - + HIP_CHECK(hipDeviceSynchronize()); // Check if number of selected value is as expected @@ -1050,6 +1050,200 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) HIP_CHECK(hipStreamDestroy(stream)); } +TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKeyAlias) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + // This test checks correctness of in-place unique_by_key (so input keys and values iterators + // are passed as output iterators as well) + using key_type = typename TestFixture::key_type; + using value_type = typename TestFixture::value_type; + using output_key_type = key_type; + using output_value_type = value_type; + + using op_type = rocprim::equal_to; + + using scan_op_type = rocprim::plus; + static constexpr bool use_identity_iterator = TestFixture::use_identity_iterator; + const bool debug_synchronous = TestFixture::debug_synchronous; + + hipStream_t stream = 0; // default stream + if(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + const auto probabilities = get_discontinuity_probabilities(); + for(auto size : test_utils::get_sizes(seed_value)) + { + SCOPED_TRACE(testing::Message() << "with size = " << size); + for(auto p : probabilities) + { + SCOPED_TRACE(testing::Message() << "with p = " << p); + + // Generate data + std::vector input_keys(size); + { + std::vector input01 + = test_utils::get_random_data01(size, p, seed_value); + std::partial_sum(input01.begin(), + input01.end(), + input_keys.begin(), + scan_op_type()); + } + const auto input_values + = test_utils::get_random_data(size, -1000, 1000, seed_value); + + // Allocate and copy to device + key_type* d_keys_input; + value_type* d_values_input; + unsigned int* d_selected_count_output; + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_keys_input, + input_keys.size() * sizeof(input_keys[0]))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_values_input, + input_values.size() + * sizeof(input_values[0]))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_selected_count_output, + sizeof(unsigned int))); + HIP_CHECK(hipMemcpy(d_keys_input, + input_keys.data(), + input_keys.size() * sizeof(input_keys[0]), + hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_values_input, + input_values.data(), + input_values.size() * sizeof(input_values[0]), + hipMemcpyHostToDevice)); + HIP_CHECK(hipDeviceSynchronize()); + + // Calculate expected results on host + std::vector expected_keys; + std::vector expected_values; + expected_keys.reserve(input_keys.size()); + expected_values.reserve(input_values.size()); + if(size > 0) + { + expected_keys.push_back(input_keys[0]); + expected_values.push_back(input_values[0]); + for(size_t i = 1; i < input_keys.size(); i++) + { + if(!op_type()(input_keys[i - 1], input_keys[i])) + { + expected_keys.push_back(input_keys[i]); + expected_values.push_back(input_values[i]); + } + } + } + + 
hipGraph_t graph; + hipGraphExec_t graph_instance; + if(TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + + // temp storage + size_t temp_storage_size_bytes; + // Get size of d_temp_storage + HIP_CHECK(rocprim::unique_by_key( + nullptr, + temp_storage_size_bytes, + d_keys_input, + d_values_input, + test_utils::wrap_in_identity_iterator(d_keys_input), + test_utils::wrap_in_identity_iterator(d_values_input), + d_selected_count_output, + input_keys.size(), + op_type(), + stream, + debug_synchronous)); + + if(TestFixture::use_graphs) + graph_instance = graph_instance + = test_utils::endCaptureGraphHelper(graph, stream, true, false); + + HIP_CHECK(hipDeviceSynchronize()); + + // temp_storage_size_bytes must be >0 + ASSERT_GT(temp_storage_size_bytes, 0); + + // allocate temporary storage + void* d_temp_storage = nullptr; + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + HIP_CHECK(hipDeviceSynchronize()); + + if(TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + + // Run + HIP_CHECK(rocprim::unique_by_key( + d_temp_storage, + temp_storage_size_bytes, + d_keys_input, + d_values_input, + test_utils::wrap_in_identity_iterator(d_keys_input), + test_utils::wrap_in_identity_iterator(d_values_input), + d_selected_count_output, + input_keys.size(), + op_type(), + stream, + debug_synchronous)); + + if(TestFixture::use_graphs) + graph_instance = graph_instance + = test_utils::endCaptureGraphHelper(graph, stream, true, false); + + HIP_CHECK(hipDeviceSynchronize()); + + // Check if number of selected value is as expected + unsigned int selected_count_output = 0; + HIP_CHECK(hipMemcpy(&selected_count_output, + d_selected_count_output, + sizeof(unsigned int), + hipMemcpyDeviceToHost)); + HIP_CHECK(hipDeviceSynchronize()); + ASSERT_EQ(selected_count_output, expected_keys.size()); + + // Check if outputs are as expected + std::vector output_keys(input_keys.size()); + 
HIP_CHECK(hipMemcpy(output_keys.data(), + d_keys_input, + output_keys.size() * sizeof(output_keys[0]), + hipMemcpyDeviceToHost)); + std::vector output_values(input_values.size()); + HIP_CHECK(hipMemcpy(output_values.data(), + d_values_input, + output_values.size() * sizeof(output_values[0]), + hipMemcpyDeviceToHost)); + HIP_CHECK(hipDeviceSynchronize()); + ASSERT_NO_FATAL_FAILURE( + test_utils::assert_eq(output_keys, expected_keys, expected_keys.size())); + ASSERT_NO_FATAL_FAILURE( + test_utils::assert_eq(output_values, expected_values, expected_values.size())); + + hipFree(d_keys_input); + hipFree(d_values_input); + hipFree(d_selected_count_output); + hipFree(d_temp_storage); + + if(TestFixture::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); + } + } + } + + if(TestFixture::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); +} + class RocprimDeviceSelectLargeInputTests : public ::testing::TestWithParam> { public: const bool debug_synchronous = false; diff --git a/test/rocprim/test_invoke_result.cpp b/test/rocprim/test_invoke_result.cpp new file mode 100644 index 000000000..c956426c8 --- /dev/null +++ b/test/rocprim/test_invoke_result.cpp @@ -0,0 +1,123 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "../common_test_header.hpp" +#include "rocprim/types.hpp" + +#include +#include + +#include +#include + +template +struct device_plus +{ + __device__ inline constexpr T operator()(const T& a, const T& b) const + { + return a + b; + } +}; + +template +struct RocprimTypeInvokeResultParams +{ + using input_type = InputType; + using function = Function; + using expected_type = ExpectedType; +}; + +template +class RocprimInvokeResultBinOpTests : public ::testing::Test +{ +public: + using input_type = typename Params::input_type; + using function = typename Params::function; + using expected_type = typename Params::expected_type; +}; + +typedef ::testing::Types< + RocprimTypeInvokeResultParams>, + RocprimTypeInvokeResultParams>, + RocprimTypeInvokeResultParams>, + RocprimTypeInvokeResultParams>, + RocprimTypeInvokeResultParams>, + RocprimTypeInvokeResultParams>, + RocprimTypeInvokeResultParams, bool>, + RocprimTypeInvokeResultParams, bool>, + RocprimTypeInvokeResultParams, bool>> + RocprimInvokeResultBinOpTestsParams; + +TYPED_TEST_SUITE(RocprimInvokeResultBinOpTests, RocprimInvokeResultBinOpTestsParams); + +TYPED_TEST(RocprimInvokeResultBinOpTests, HostInvokeResult) +{ + using input_type = typename TestFixture::input_type; + using binary_function = typename TestFixture::function; + using expected_type = typename TestFixture::expected_type; + + using resulting_type = rocprim::invoke_result_binary_op_t; + + // Compile and check on host + static_assert(std::is_same::value, + 
"Resulting type is not equal to expected type!"); +} + +template +struct static_cast_op +{ + __device__ inline constexpr ToType operator()(FromType a) const + { + return static_cast(a); + } +}; + +template +class RocprimInvokeResultUnOpTests : public ::testing::Test +{ +public: + using input_type = typename Params::input_type; + using function = typename Params::function; + using expected_type = typename Params::expected_type; +}; + +typedef ::testing::Types< + RocprimTypeInvokeResultParams, float>, + RocprimTypeInvokeResultParams, + rocprim::bfloat16>, + RocprimTypeInvokeResultParams>> + RocprimInvokeResultUnOpTestsParams; + +TYPED_TEST_SUITE(RocprimInvokeResultUnOpTests, RocprimInvokeResultUnOpTestsParams); + +TYPED_TEST(RocprimInvokeResultUnOpTests, HostInvokeResult) +{ + using input_type = typename TestFixture::input_type; + using unary_function = typename TestFixture::function; + using expected_type = typename TestFixture::expected_type; + + using resulting_type = rocprim::invoke_result_t; + + static_assert(std::is_same::value, + "Resulting type is not equal to expected type!"); +} diff --git a/test/rocprim/test_utils.hpp b/test/rocprim/test_utils.hpp index 13dea15b6..f5010e8eb 100644 --- a/test/rocprim/test_utils.hpp +++ b/test/rocprim/test_utils.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -21,7 +21,6 @@ #ifndef TEST_TEST_UTILS_HPP_ #define TEST_TEST_UTILS_HPP_ -#include #include #include #include @@ -446,12 +445,8 @@ void iota(ForwardIt first, ForwardIt last, T value) } template -struct DeviceSelectWarpSize -{ - static constexpr unsigned value = ::rocprim::device_warp_size() >= LogicalWarpSize - ? 
LogicalWarpSize - : ::rocprim::device_warp_size(); -}; +__device__ constexpr bool device_test_enabled_for_warp_size_v + = ::rocprim::device_warp_size() >= LogicalWarpSize; } // end test_utils namespace diff --git a/test/rocprim/test_utils_assertions.hpp b/test/rocprim/test_utils_assertions.hpp index 45b65583e..073b783c9 100644 --- a/test/rocprim/test_utils_assertions.hpp +++ b/test/rocprim/test_utils_assertions.hpp @@ -245,8 +245,7 @@ void assert_bit_eq(const std::vector& result, const std::vector& expected) } } } - -#if defined(__GNUC__) || defined(__clang__) +#if ROCPRIM_HAS_INT128_SUPPORT inline void assert_bit_eq(const std::vector<__int128_t>& result, const std::vector<__int128_t>& expected) { diff --git a/test/rocprim/test_utils_types.hpp b/test/rocprim/test_utils_types.hpp index 1f793cbda..fb901adc9 100644 --- a/test/rocprim/test_utils_types.hpp +++ b/test/rocprim/test_utils_types.hpp @@ -158,9 +158,13 @@ typedef ::testing::Types< typedef ::testing::Types), block_param_type(uint8_t, short), - block_param_type(int8_t, float), + block_param_type(int8_t, float) +#if ROCPRIM_HAS_INT128_SUPPORT + , block_param_type(__uint128_t, short), - block_param_type(__int128_t, float)> + block_param_type(__int128_t, float) +#endif + > BlockParamsIntegralExtended; typedef ::testing::Types< diff --git a/test/rocprim/test_warp_exchange.cpp b/test/rocprim/test_warp_exchange.cpp index 769d6764a..0906b3681 100644 --- a/test/rocprim/test_warp_exchange.cpp +++ b/test/rocprim/test_warp_exchange.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -25,12 +25,13 @@ #include -template< - class T, - unsigned int ItemsPerThread, - unsigned int WarpSize, - class ExchangeOp -> +#include +#include +#include + +#include + +template struct Params { using type = T; @@ -128,70 +129,67 @@ struct ScatterToStripedOp } }; -using WarpExchangeTestParams = ::testing::Types< - Params, - Params, - Params, - Params, - Params, - Params, - - Params, - Params, - Params, - Params, - Params, - - Params, - Params, - Params, - Params, - Params, - Params, - - Params, - Params, - Params, - Params, - Params ->; - -template< - class T, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int LogicalWarpSize, - class Op -> -__global__ -__launch_bounds__(BlockSize) -void warp_exchange_kernel(T* d_input, - T* d_output) +using WarpExchangeTestParams + = ::testing::Types, + Params, + Params, + Params, + Params, + Params, + + Params, + Params, + Params, + Params, + Params, + + Params, + Params, + Params, + Params, + Params, + Params, + + Params, + Params, + Params, + Params, + Params>; + +template +__device__ auto warp_exchange_test(T* d_input, T* d_output) + -> std::enable_if_t> { - static_assert(BlockSize == LogicalWarpSize, - "BlockSize must be equal to LogicalWarpSize in this test"); - using warp_exchange_type = ::rocprim::warp_exchange< - T, - ItemsPerThread, - test_utils::DeviceSelectWarpSize::value - >; - - ROCPRIM_SHARED_MEMORY typename warp_exchange_type::storage_type storage; + using warp_exchange_type = ::rocprim::warp_exchange; + constexpr unsigned int num_warps = ::rocprim::device_warp_size() / LogicalWarpSize; + ROCPRIM_SHARED_MEMORY typename warp_exchange_type::storage_type storage[num_warps]; T thread_data[ItemsPerThread]; for(unsigned int i = 0; i < ItemsPerThread; i++) { - thread_data[i] = d_input[hipThreadIdx_x * ItemsPerThread + i]; + 
thread_data[i] = d_input[threadIdx.x * ItemsPerThread + i]; } - Op{}(warp_exchange_type(), thread_data, storage); + const unsigned int warp_id = threadIdx.x / LogicalWarpSize; + Op{}(warp_exchange_type(), thread_data, storage[warp_id]); for(unsigned int i = 0; i < ItemsPerThread; i++) { - d_output[hipThreadIdx_x * ItemsPerThread + i] = thread_data[i]; + d_output[threadIdx.x * ItemsPerThread + i] = thread_data[i]; } } +template +__device__ auto warp_exchange_test(T* /*d_input*/, T* /*d_output*/) + -> std::enable_if_t> +{} + +template +__global__ void warp_exchange_kernel(T* d_input, T* d_output) +{ + warp_exchange_test(d_input, d_output); +} + template std::vector stripe_vector(const std::vector& v, const size_t warp_size, @@ -217,13 +215,16 @@ TYPED_TEST(WarpExchangeTest, WarpExchange) using T = typename TestFixture::params::type; constexpr unsigned int warp_size = TestFixture::params::warp_size; constexpr unsigned int items_per_thread = TestFixture::params::items_per_thread; - using exchange_op = typename TestFixture::params::exchange_op; - constexpr unsigned int block_size = warp_size; - constexpr unsigned int items_count = items_per_thread * block_size; + using exchange_op = typename TestFixture::params::exchange_op; - int device_id = test_common_utils::obtain_device_from_ctest(); + const int device_id = test_common_utils::obtain_device_from_ctest(); SKIP_IF_UNSUPPORTED_WARP_SIZE(warp_size, device_id); + unsigned int hw_warp_size; + HIP_CHECK(::rocprim::host_warp_size(device_id, hw_warp_size)); + const unsigned int block_size = hw_warp_size; + const unsigned int items_count = items_per_thread * block_size; + std::vector input(items_count); std::iota(input.begin(), input.end(), static_cast(0)); auto expected = input; @@ -238,20 +239,10 @@ TYPED_TEST(WarpExchangeTest, WarpExchange) HIP_CHECK(hipMemcpy(d_input, input.data(), items_count * sizeof(T), hipMemcpyHostToDevice)); T* d_output{}; HIP_CHECK(hipMalloc(&d_output, items_count * sizeof(T))); + 
HIP_CHECK(hipMemset(d_output, 0, items_count * sizeof(T))); - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - warp_exchange_kernel< - T, - block_size, - items_per_thread, - warp_size, - exchange_op - > - ), - dim3(1), dim3(block_size), 0, 0, - d_input, d_output - ); + warp_exchange_kernel + <<>>(d_input, d_output); HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -270,13 +261,12 @@ TYPED_TEST(WarpExchangeTest, WarpExchange) ASSERT_EQ(expected, output); } -using WarpExchangeScatterTestParams = ::testing::Types< - Params, - Params, - Params, - Params, - Params - >; +using WarpExchangeScatterTestParams = ::testing::Types, + Params, + Params, + Params, + Params, + Params>; template class WarpExchangeScatterTest : public ::testing::Test @@ -285,46 +275,46 @@ class WarpExchangeScatterTest : public ::testing::Test using params = Params; }; -template< - class T, - class OffsetT, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int LogicalWarpSize, - class Op -> -__global__ -__launch_bounds__(BlockSize) -void warp_exchange_scatter_kernel(T* d_input, - T* d_output, - OffsetT* d_ranks) +template +__device__ auto warp_exchange_scatter_test(T* d_input, T* d_output, OffsetT* d_ranks) + -> std::enable_if_t> { - static_assert(BlockSize == LogicalWarpSize, - "BlockSize must be equal to LogicalWarpSize in this test"); - using warp_exchange_type = ::rocprim::warp_exchange< - T, - ItemsPerThread, - test_utils::DeviceSelectWarpSize::value - >; + using warp_exchange_type = ::rocprim::warp_exchange; - ROCPRIM_SHARED_MEMORY typename warp_exchange_type::storage_type storage; + constexpr unsigned int num_warps = ::rocprim::device_warp_size() / LogicalWarpSize; + ROCPRIM_SHARED_MEMORY typename warp_exchange_type::storage_type storage[num_warps]; T thread_data[ItemsPerThread]; OffsetT thread_ranks[ItemsPerThread]; for(unsigned int i = 0; i < ItemsPerThread; i++) { - thread_data[i] = d_input[hipThreadIdx_x * ItemsPerThread + i]; - thread_ranks[i] = 
d_ranks[hipThreadIdx_x * ItemsPerThread + i]; + thread_data[i] = d_input[threadIdx.x * ItemsPerThread + i]; + thread_ranks[i] = d_ranks[threadIdx.x * ItemsPerThread + i]; } - Op{}(warp_exchange_type(), thread_data, thread_ranks, storage); + const unsigned int warp_id = threadIdx.x / LogicalWarpSize; + warp_exchange_type{}.scatter_to_striped(thread_data, + thread_data, + thread_ranks, + storage[warp_id]); for(unsigned int i = 0; i < ItemsPerThread; i++) { - d_output[hipThreadIdx_x * ItemsPerThread + i] = thread_data[i]; + d_output[threadIdx.x * ItemsPerThread + i] = thread_data[i]; } } +template +__device__ auto warp_exchange_scatter_test(T* /*d_input*/, T* /*d_output*/, OffsetT* /*d_ranks*/) + -> std::enable_if_t> +{} + +template +__global__ void warp_exchange_scatter_kernel(T* d_input, T* d_output, OffsetT* d_ranks) +{ + warp_exchange_scatter_test(d_input, d_output, d_ranks); +} + TYPED_TEST_SUITE(WarpExchangeScatterTest, WarpExchangeScatterTestParams); TYPED_TEST(WarpExchangeScatterTest, WarpExchangeScatter) @@ -332,43 +322,42 @@ TYPED_TEST(WarpExchangeScatterTest, WarpExchangeScatter) using T = typename TestFixture::params::type; constexpr unsigned int warp_size = TestFixture::params::warp_size; constexpr unsigned int items_per_thread = TestFixture::params::items_per_thread; - using exchange_op = typename TestFixture::params::exchange_op; - constexpr unsigned int block_size = warp_size; - constexpr unsigned int items_count = items_per_thread * block_size; using OffsetT = unsigned short; - int device_id = test_common_utils::obtain_device_from_ctest(); + const int device_id = test_common_utils::obtain_device_from_ctest(); SKIP_IF_UNSUPPORTED_WARP_SIZE(warp_size, device_id); + unsigned int hw_warp_size; + HIP_CHECK(::rocprim::host_warp_size(device_id, hw_warp_size)); + const unsigned int block_size = hw_warp_size; + const unsigned int items_count = items_per_thread * block_size; std::vector input(items_count); std::iota(input.begin(), input.end(), static_cast(0)); 
- auto expected = input; - std::shuffle(input.begin(), input.end(), std::default_random_engine{std::random_device{}()}); - std::vector ranks(input.begin(), input.end()); + const auto expected = stripe_vector(input, warp_size, items_per_thread); + std::default_random_engine prng(std::random_device{}()); + for(auto it = input.begin(), end = input.end(); it != end; it += warp_size) + { + std::shuffle(it, it + warp_size, prng); + } + std::vector ranks(items_count); + std::transform(input.begin(), + input.end(), + ranks.begin(), + [](const T input_val) + { return static_cast(input_val) % (warp_size * items_per_thread); }); T* d_input{}; HIP_CHECK(hipMalloc(&d_input, items_count * sizeof(T))); HIP_CHECK(hipMemcpy(d_input, input.data(), items_count * sizeof(T), hipMemcpyHostToDevice)); T* d_output{}; HIP_CHECK(hipMalloc(&d_output, items_count * sizeof(T))); + HIP_CHECK(hipMemset(d_output, 0, items_count * sizeof(T))); OffsetT* d_ranks{}; HIP_CHECK(hipMalloc(&d_ranks, items_count * sizeof(OffsetT))); HIP_CHECK(hipMemcpy(d_ranks, ranks.data(), items_count * sizeof(OffsetT), hipMemcpyHostToDevice)); - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - warp_exchange_scatter_kernel< - T, - OffsetT, - block_size, - items_per_thread, - warp_size, - exchange_op - > - ), - dim3(1), dim3(block_size), 0, 0, - d_input, d_output, d_ranks - ); + warp_exchange_scatter_kernel + <<>>(d_input, d_output, d_ranks); HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -379,10 +368,5 @@ TYPED_TEST(WarpExchangeScatterTest, WarpExchangeScatter) HIP_CHECK(hipFree(d_output)); HIP_CHECK(hipFree(d_ranks)); - if(std::is_same::value) - { - expected = stripe_vector(expected, warp_size, items_per_thread); - } - ASSERT_EQ(expected, output); -} \ No newline at end of file +} diff --git a/test/rocprim/test_warp_load.cpp b/test/rocprim/test_warp_load.cpp index e1a14cfa0..c00e0501d 100644 --- a/test/rocprim/test_warp_load.cpp +++ b/test/rocprim/test_warp_load.cpp @@ -1,6 +1,6 @@ // MIT License // -// 
Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -24,6 +24,7 @@ #include "test_utils.hpp" #include +#include template< class T, @@ -78,29 +79,20 @@ using WarpLoadTestParams = ::testing::Types< Params >; -template< - class T, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int LogicalWarpSize, - ::rocprim::warp_load_method Method -> -__global__ -__launch_bounds__(BlockSize) -void warp_load_kernel(T* d_input, - T* d_output) +template +__device__ auto warp_load_test(T* d_input, T* d_output) + -> std::enable_if_t> { static_assert(BlockSize % LogicalWarpSize == 0, "LogicalWarpSize must be a divisor of BlockSize"); - using warp_load_type = ::rocprim::warp_load< - T, - ItemsPerThread, - test_utils::DeviceSelectWarpSize::value, - Method - >; + using warp_load_type = ::rocprim::warp_load; constexpr unsigned int tile_size = ItemsPerThread * LogicalWarpSize; constexpr unsigned int num_warps = BlockSize / LogicalWarpSize; - const unsigned int warp_id = hipThreadIdx_x / LogicalWarpSize; + const unsigned int warp_id = threadIdx.x / LogicalWarpSize; ROCPRIM_SHARED_MEMORY typename warp_load_type::storage_type storage[num_warps]; T thread_data[ItemsPerThread]; @@ -109,35 +101,43 @@ void warp_load_kernel(T* d_input, for(unsigned int i = 0; i < ItemsPerThread; i++) { - d_output[hipThreadIdx_x * ItemsPerThread + i] = thread_data[i]; + d_output[threadIdx.x * ItemsPerThread + i] = thread_data[i]; } } -template< - class T, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int LogicalWarpSize, - ::rocprim::warp_load_method Method -> -__global__ -__launch_bounds__(BlockSize) -void warp_load_guarded_kernel(T* d_input, - T* d_output, - int valid_items, - T oob_default) +template 
+__device__ auto warp_load_test(T* /*d_input*/, T* /*d_output*/) + -> std::enable_if_t> +{} + +template +__global__ __launch_bounds__(BlockSize) void warp_load_kernel(T* d_input, T* d_output) +{ + warp_load_test(d_input, d_output); +} + +template +__device__ auto warp_load_guarded_test(T* d_input, T* d_output, int valid_items, T oob_default) + -> std::enable_if_t> { static_assert(BlockSize % LogicalWarpSize == 0, "LogicalWarpSize must be a divisor of BlockSize"); - using warp_load_type = ::rocprim::warp_load< - T, - ItemsPerThread, - test_utils::DeviceSelectWarpSize::value, - Method - >; + using warp_load_type = ::rocprim::warp_load; constexpr unsigned int tile_size = ItemsPerThread * LogicalWarpSize; constexpr unsigned int num_warps = BlockSize / LogicalWarpSize; - const unsigned warp_id = hipThreadIdx_x / LogicalWarpSize; + const unsigned warp_id = threadIdx.x / LogicalWarpSize; ROCPRIM_SHARED_MEMORY typename warp_load_type::storage_type storage[num_warps]; T thread_data[ItemsPerThread]; @@ -152,10 +152,36 @@ void warp_load_guarded_kernel(T* d_input, for(unsigned int i = 0; i < ItemsPerThread; i++) { - d_output[hipThreadIdx_x * ItemsPerThread + i] = thread_data[i]; + d_output[threadIdx.x * ItemsPerThread + i] = thread_data[i]; } } +template +__device__ auto + warp_load_guarded_test(T* /*d_input*/, T* /*d_output*/, int /*valid_items*/, T /*oob_default*/) + -> std::enable_if_t> +{} + +template +__global__ __launch_bounds__(BlockSize) void warp_load_guarded_kernel(T* d_input, + T* d_output, + int valid_items, + T oob_default) +{ + warp_load_guarded_test(d_input, + d_output, + valid_items, + oob_default); +} + template std::vector stripe_vector(const std::vector& v, const size_t warp_size, @@ -197,19 +223,8 @@ TYPED_TEST(WarpLoadTest, WarpLoad) T* d_output{}; HIP_CHECK(hipMalloc(&d_output, items_count * sizeof(T))); - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - warp_load_kernel< - T, - block_size, - items_per_thread, - warp_size, - method - > - ), - dim3(1), 
dim3(block_size), 0, 0, - d_input, d_output - ); + warp_load_kernel + <<>>(d_input, d_output); HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -251,20 +266,8 @@ TYPED_TEST(WarpLoadTest, WarpLoadGuarded) T* d_output{}; HIP_CHECK(hipMalloc(&d_output, items_count * sizeof(T))); - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - warp_load_guarded_kernel< - T, - block_size, - items_per_thread, - warp_size, - method - > - ), - dim3(1), dim3(block_size), 0, 0, - d_input, d_output, - valid_items, oob_default - ); + warp_load_guarded_kernel + <<>>(d_input, d_output, valid_items, oob_default); HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); diff --git a/test/rocprim/test_warp_scan.hpp b/test/rocprim/test_warp_scan.hpp index 7eb807f23..99d82f33c 100644 --- a/test/rocprim/test_warp_scan.hpp +++ b/test/rocprim/test_warp_scan.hpp @@ -417,6 +417,149 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveScan) } +typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveScanWoInit) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using T = typename TestFixture::params::type; + // for bfloat16 and half we use double for host-side accumulation + using binary_op_type_host = typename test_utils::select_plus_operator_host::type; + binary_op_type_host binary_op_host; + using acc_type = typename test_utils::select_plus_operator_host::acc_type; + using cast_type = typename test_utils::select_plus_operator_host::cast_type; + + // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() + static constexpr size_t logical_warp_size = TestFixture::params::warp_size; + + // The different warp sizes + static constexpr size_t ws32{ROCPRIM_WARP_SIZE_32}; + static constexpr size_t ws64{ROCPRIM_WARP_SIZE_64}; + + // Block size of warp size 32 + static constexpr size_t block_size_ws32 + = 
rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; + + // Block size of warp size 64 + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; + + unsigned int current_device_warp_size; + HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); + + const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; + const unsigned int grid_size = 4; + const size_t size = block_size * grid_size; + + // Check if warp size is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only WarpSize 32 and 64 is supported + { + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%d. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); + GTEST_SKIP(); + } + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + // Generate data + std::vector input = test_utils::get_random_data(size, 2, 50, seed_value); + std::vector output(size); + std::vector expected(input.size(), T(0)); + + // Calculate expected results on host + for(size_t i = 0; i < input.size() / logical_warp_size; i++) + { + // expected[i * logical_warp_size] is unspecified because init is not passed + acc_type accumulator(input[i * logical_warp_size]); + + static_assert(logical_warp_size > 2, "logical_warp_size assumed to be at least 2."); + expected[i * logical_warp_size + 1] = static_cast(accumulator); + + for(size_t j = 2; j < logical_warp_size; j++) + { + auto idx = i * logical_warp_size + j; + accumulator = binary_op_host(input[idx - 1], accumulator); + expected[idx] = static_cast(accumulator); + } + } + + // Writing to device memory + T* device_input; + HIP_CHECK(test_common_utils::hipMallocHelper( + &device_input, + input.size() * sizeof(typename decltype(input)::value_type))); + T* device_output; + HIP_CHECK(test_common_utils::hipMallocHelper( + &device_output, + output.size() * sizeof(typename decltype(output)::value_type))); + + HIP_CHECK( + hipMemcpy(device_input, input.data(), input.size() * sizeof(T), hipMemcpyHostToDevice)); + + // Launching kernel + if(current_device_warp_size == ws32) + { + hipLaunchKernelGGL( + HIP_KERNEL_NAME( + warp_exclusive_scan_wo_init_kernel), + dim3(grid_size), + dim3(block_size_ws32), + 0, + 0, + device_input, + device_output); + } + else if(current_device_warp_size == ws64) + { + hipLaunchKernelGGL( + HIP_KERNEL_NAME( + warp_exclusive_scan_wo_init_kernel), + dim3(grid_size), + dim3(block_size_ws64), + 0, + 0, + device_input, + device_output); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipDeviceSynchronize()); + + // Read from device memory + HIP_CHECK(hipMemcpy(output.data(), + device_output, + output.size() * sizeof(T), + 
hipMemcpyDeviceToHost)); + + // The first value of each logical warp has an unspecified result, expect whatever we got + // for those values to not fail the test. + for(size_t i = 0; i < input.size() / logical_warp_size; i++) + { + expected[i * logical_warp_size] = output[i * logical_warp_size]; + } + + // Validating results + test_utils::assert_near(output, expected, test_utils::precision * logical_warp_size); + + HIP_CHECK(hipFree(device_input)); + HIP_CHECK(hipFree(device_output)); + } +} + //typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScan) typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScan) { @@ -571,6 +714,175 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScan) } +typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScanWoInit) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using T = typename TestFixture::params::type; + // for bfloat16 and half we use double for host-side accumulation + using binary_op_type_host = typename test_utils::select_plus_operator_host::type; + binary_op_type_host binary_op_host; + using acc_type = typename test_utils::select_plus_operator_host::acc_type; + using cast_type = typename test_utils::select_plus_operator_host::cast_type; + + // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() + static constexpr size_t logical_warp_size = TestFixture::params::warp_size; + + // The different warp sizes + static constexpr size_t ws32 = size_t(ROCPRIM_WARP_SIZE_32); + static constexpr size_t ws64 = size_t(ROCPRIM_WARP_SIZE_64); + + // Block size of warp size 32 + static constexpr size_t block_size_ws32 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? 
rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; + + // Block size of warp size 64 + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; + + unsigned int current_device_warp_size; + HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); + + const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; + const unsigned int grid_size = 4; + const size_t size = block_size * grid_size; + + // Check if warp size is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only WarpSize 32 and 64 is supported + { + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%d. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); + GTEST_SKIP(); + } + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + // Generate data + std::vector input = test_utils::get_random_data(size, 2, 50, seed_value); + std::vector output(size); + std::vector output_reductions(size / logical_warp_size); + std::vector expected(input.size(), T(0)); + std::vector expected_reductions(output_reductions.size(), T(0)); + + // Calculate expected results on host + for(size_t i = 0; i < input.size() / logical_warp_size; i++) + { + // expected[i * logical_warp_size] is unspecified because init is not passed + acc_type accumulator(input[i * logical_warp_size]); + + static_assert(logical_warp_size > 2, "logical_warp_size assumed to be at least 2."); + expected[i * logical_warp_size + 1] = static_cast(accumulator); + + for(size_t j = 2; j < logical_warp_size; j++) + { + auto idx = i * logical_warp_size + j; + accumulator = binary_op_host(input[idx - 1], accumulator); + expected[idx] = static_cast(accumulator); + } + + acc_type accumulator_reductions(0); + for(size_t j = 0; j < logical_warp_size; j++) + { + auto idx = i * logical_warp_size + j; + accumulator_reductions = binary_op_host(input[idx], accumulator_reductions); + expected_reductions[i] = static_cast(accumulator_reductions); + } + } + + // Writing to device memory + T* device_input; + HIP_CHECK(test_common_utils::hipMallocHelper( + &device_input, + input.size() * sizeof(typename decltype(input)::value_type))); + T* device_output; + HIP_CHECK(test_common_utils::hipMallocHelper( + &device_output, + output.size() * sizeof(typename decltype(output)::value_type))); + T* device_output_reductions; + HIP_CHECK(test_common_utils::hipMallocHelper( + &device_output_reductions, + output_reductions.size() * sizeof(typename decltype(output_reductions)::value_type))); + + HIP_CHECK( + hipMemcpy(device_input, input.data(), input.size() * sizeof(T), hipMemcpyHostToDevice)); + + // Launching kernel + if(current_device_warp_size == ws32) + { 
+ hipLaunchKernelGGL( + HIP_KERNEL_NAME(warp_exclusive_scan_reduce_wo_init_kernel), + dim3(grid_size), + dim3(block_size_ws32), + 0, + 0, + device_input, + device_output, + device_output_reductions); + } + else if(current_device_warp_size == ws64) + { + hipLaunchKernelGGL( + HIP_KERNEL_NAME(warp_exclusive_scan_reduce_wo_init_kernel), + dim3(grid_size), + dim3(block_size_ws64), + 0, + 0, + device_input, + device_output, + device_output_reductions); + } + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipDeviceSynchronize()); + + // Read from device memory + HIP_CHECK(hipMemcpy(output.data(), + device_output, + output.size() * sizeof(T), + hipMemcpyDeviceToHost)); + + HIP_CHECK(hipMemcpy(output_reductions.data(), + device_output_reductions, + output_reductions.size() * sizeof(T), + hipMemcpyDeviceToHost)); + + // The first value of each logical warp has an unspecified result, expect whatever we got + // for those values to not fail the test. + for(size_t i = 0; i < input.size() / logical_warp_size; i++) + { + expected[i * logical_warp_size] = output[i * logical_warp_size]; + } + + // Validating results + test_utils::assert_near(output, expected, test_utils::precision * logical_warp_size); + test_utils::assert_near(output_reductions, + expected_reductions, + test_utils::precision * logical_warp_size); + + HIP_CHECK(hipFree(device_input)); + HIP_CHECK(hipFree(device_output)); + HIP_CHECK(hipFree(device_output_reductions)); + } +} + typed_test_def(RocprimWarpScanTests, name_suffix, Scan) { int device_id = test_common_utils::obtain_device_from_ctest(); diff --git a/test/rocprim/test_warp_scan.kernels.hpp b/test/rocprim/test_warp_scan.kernels.hpp index 11f859ec6..875389058 100644 --- a/test/rocprim/test_warp_scan.kernels.hpp +++ b/test/rocprim/test_warp_scan.kernels.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -128,6 +128,49 @@ void warp_exclusive_scan_reduce_kernel( } } +template +__global__ __launch_bounds__(BlockSize) void warp_exclusive_scan_wo_init_kernel(T* device_input, + T* device_output) +{ + static constexpr unsigned int block_warps_no = BlockSize / LogicalWarpSize; + + const unsigned int global_index = threadIdx.x + (blockIdx.x * blockDim.x); + const unsigned int block_warp_id = threadIdx.x / LogicalWarpSize; + + T value = device_input[global_index]; + + using wscan_t = rocprim::warp_scan; + __shared__ typename wscan_t::storage_type storage[block_warps_no]; + wscan_t().exclusive_scan(value, value, storage[block_warp_id]); + + device_output[global_index] = value; +} + +template +__global__ __launch_bounds__(BlockSize) void warp_exclusive_scan_reduce_wo_init_kernel( + T* device_input, T* device_output, T* device_output_reductions) +{ + static constexpr unsigned int block_warps_no = BlockSize / LogicalWarpSize; + + const unsigned int global_index = threadIdx.x + (blockIdx.x * blockDim.x); + const unsigned int block_warp_id = threadIdx.x / LogicalWarpSize; + const unsigned int lane_id = threadIdx.x % LogicalWarpSize; + const unsigned int global_warp_id = global_index / LogicalWarpSize; + + T value = device_input[global_index]; + T reduction; + + using wscan_t = rocprim::warp_scan; + __shared__ typename wscan_t::storage_type storage[block_warps_no]; + wscan_t().exclusive_scan(value, value, storage[block_warp_id], reduction); + + device_output[global_index] = value; + if(lane_id == 0) + { + device_output_reductions[global_warp_id] = reduction; + } +} + template< class T, unsigned int BlockSize, diff --git a/test/rocprim/test_warp_store.cpp b/test/rocprim/test_warp_store.cpp index 2d6a0430a..63aab5355 100644 --- a/test/rocprim/test_warp_store.cpp +++ b/test/rocprim/test_warp_store.cpp @@ -1,6 +1,6 @@ // 
MIT License // -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -24,6 +24,7 @@ #include "test_utils.hpp" #include +#include template< class T, @@ -78,66 +79,66 @@ using WarpStoreTestParams = ::testing::Types< Params >; -template< - class T, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int LogicalWarpSize, - ::rocprim::warp_store_method Method -> -__global__ -__launch_bounds__(BlockSize) -void warp_store_kernel(T* d_input, - T* d_output) +template +__device__ auto warp_store_test(T* d_input, T* d_output) + -> std::enable_if_t> { - using warp_store_type = ::rocprim::warp_store< - T, - ItemsPerThread, - test_utils::DeviceSelectWarpSize::value, - Method - >; + using warp_store_type = ::rocprim::warp_store; constexpr unsigned int tile_size = ItemsPerThread * LogicalWarpSize; constexpr unsigned int num_warps = BlockSize / LogicalWarpSize; - const unsigned int warp_id = hipThreadIdx_x / LogicalWarpSize; + const unsigned int warp_id = threadIdx.x / LogicalWarpSize; ROCPRIM_SHARED_MEMORY typename warp_store_type::storage_type storage[num_warps]; T thread_data[ItemsPerThread]; for(unsigned int i = 0; i < ItemsPerThread; ++i) { - thread_data[i] = d_input[hipThreadIdx_x * ItemsPerThread + i]; + thread_data[i] = d_input[threadIdx.x * ItemsPerThread + i]; } warp_store_type().store(d_output + warp_id * tile_size, thread_data, storage[warp_id]); } -template< - class T, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int LogicalWarpSize, - ::rocprim::warp_store_method Method -> -__global__ -__launch_bounds__(BlockSize) -void warp_store_guarded_kernel(T* d_input, - T* d_output, - int valid_items) +template +__device__ auto warp_store_test(T* /*d_input*/, T* /*d_output*/) 
+ -> std::enable_if_t> +{} + +template +__global__ __launch_bounds__(BlockSize) void warp_store_kernel(T* d_input, T* d_output) { - using warp_store_type = ::rocprim::warp_store< - T, - ItemsPerThread, - test_utils::DeviceSelectWarpSize::value, - Method - >; + warp_store_test(d_input, d_output); +} + +template +__device__ auto warp_store_guarded_test(T* d_input, T* d_output, int valid_items) + -> std::enable_if_t> +{ + using warp_store_type = ::rocprim::warp_store; constexpr unsigned int tile_size = ItemsPerThread * LogicalWarpSize; constexpr unsigned int num_warps = BlockSize / LogicalWarpSize; - const unsigned int warp_id = hipThreadIdx_x / LogicalWarpSize; + const unsigned int warp_id = threadIdx.x / LogicalWarpSize; ROCPRIM_SHARED_MEMORY typename warp_store_type::storage_type storage[num_warps]; T thread_data[ItemsPerThread]; for(unsigned int i = 0; i < ItemsPerThread; ++i) { - thread_data[i] = d_input[hipThreadIdx_x * ItemsPerThread + i]; + thread_data[i] = d_input[threadIdx.x * ItemsPerThread + i]; } warp_store_type().store(d_output + warp_id * tile_size, @@ -147,6 +148,29 @@ void warp_store_guarded_kernel(T* d_input, ); } +template +__device__ auto warp_store_guarded_test(T* /*d_input*/, T* /*d_output*/, int /*valid_items*/) + -> std::enable_if_t> +{} + +template +__global__ __launch_bounds__(BlockSize) void warp_store_guarded_kernel(T* d_input, + T* d_output, + int valid_items) +{ + warp_store_guarded_test(d_input, + d_output, + valid_items); +} + template std::vector stripe_vector(const std::vector& v, const size_t warp_size, @@ -187,19 +211,8 @@ TYPED_TEST(WarpStoreTest, WarpLoad) T* d_output{}; HIP_CHECK(hipMalloc(&d_output, items_count * sizeof(T))); - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - warp_store_kernel< - T, - block_size, - items_per_thread, - warp_size, - method - > - ), - dim3(1), dim3(block_size), 0, 0, - d_input, d_output - ); + warp_store_kernel + <<>>(d_input, d_output); HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ 
-241,19 +254,8 @@ TYPED_TEST(WarpStoreTest, WarpStoreGuarded) HIP_CHECK(hipMalloc(&d_output, items_count * sizeof(T))); HIP_CHECK(hipMemset(d_output, 0, items_count * sizeof(T))); - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - warp_store_guarded_kernel< - T, - block_size, - items_per_thread, - warp_size, - method - > - ), - dim3(1), dim3(block_size), 0, 0, - d_input, d_output, valid_items - ); + warp_store_guarded_kernel + <<>>(d_input, d_output, valid_items); HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize());