Skip to content

Commit

Permalink
Cherry-pick: Optimize block_reduce_warp_reduce when block size is the…
Browse files Browse the repository at this point in the history
… same as warp size (#599)

* Optimize block_reduce_warp_reduce when block size == warp size

* Make conditional constexpr
  • Loading branch information
stanleytsang-amd authored Aug 27, 2024
1 parent eab1eed commit 93501cf
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 30 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
Documentation for rocPRIM is available at
[https://rocm.docs.amd.com/projects/rocPRIM/en/latest/](https://rocm.docs.amd.com/projects/rocPRIM/en/latest/).

## Unreleased rocPRIM-3.2.0 for ROCm 6.2.0
## rocPRIM-3.2.1 for ROCm 6.2.1

### Optimizations
* Improved performance of block_reduce_warp_reduce when warp size == block size.

## rocPRIM-3.2.0 for ROCm 6.2.0

### Additions

Expand Down
66 changes: 37 additions & 29 deletions rocprim/include/rocprim/block/detail/block_reduce_warp_reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,21 +178,25 @@ class block_reduce_warp_reduce
input, output, num_valid, reduce_op
);

// i-th warp will have its partial stored in storage_.warp_partials[i-1]
if(lane_id == 0)
// Final reduction across warps is only required if there is more than 1 warp
if ROCPRIM_IF_CONSTEXPR (warps_no_ > 1)
{
storage_.warp_partials[warp_id] = output;
}
::rocprim::syncthreads();

if(flat_tid < warps_no_)
{
// Use warp partial to calculate the final reduce results for every thread
auto warp_partial = storage_.warp_partials[lane_id];

warp_reduce<!warps_no_is_pow_of_two_, warp_reduce_output_type>(
warp_partial, output, warps_no_, reduce_op
);
// i-th warp will have its partial stored in storage_.warp_partials[i-1]
if(lane_id == 0)
{
storage_.warp_partials[warp_id] = output;
}
::rocprim::syncthreads();

if(flat_tid < warps_no_)
{
// Use warp partial to calculate the final reduce results for every thread
auto warp_partial = storage_.warp_partials[lane_id];

warp_reduce<!warps_no_is_pow_of_two_, warp_reduce_output_type>(
warp_partial, output, warps_no_, reduce_op
);
}
}
}

Expand Down Expand Up @@ -244,22 +248,26 @@ class block_reduce_warp_reduce
input, output, num_valid, reduce_op
);

// i-th warp will have its partial stored in storage_.warp_partials[i-1]
if(lane_id == 0)
// Final reduction across warps is only required if there is more than 1 warp
if ROCPRIM_IF_CONSTEXPR (warps_no_ > 1)
{
storage_.warp_partials[warp_id] = output;
}
::rocprim::syncthreads();

if(flat_tid < warps_no_)
{
// Use warp partial to calculate the final reduce results for every thread
auto warp_partial = storage_.warp_partials[lane_id];

unsigned int valid_warps_no = (valid_items + warp_size_ - 1) / warp_size_;
warp_reduce_output_type().reduce(
warp_partial, output, valid_warps_no, reduce_op
);
// i-th warp will have its partial stored in storage_.warp_partials[i-1]
if(lane_id == 0)
{
storage_.warp_partials[warp_id] = output;
}
::rocprim::syncthreads();

if(flat_tid < warps_no_)
{
// Use warp partial to calculate the final reduce results for every thread
auto warp_partial = storage_.warp_partials[lane_id];

unsigned int valid_warps_no = (valid_items + warp_size_ - 1) / warp_size_;
warp_reduce_output_type().reduce(
warp_partial, output, valid_warps_no, reduce_op
);
}
}
}
};
Expand Down

0 comments on commit 93501cf

Please sign in to comment.