Skip to content

Commit

Permalink
Added check to partition kernel if size is smaller than items_per_blo…
Browse files Browse the repository at this point in the history
…ck (#538) (#546)

Co-authored-by: Nick Breed <[email protected]>
Co-authored-by: Nick Breed <[email protected]>
  • Loading branch information
3 people authored Apr 18, 2024
1 parent 435f7f4 commit 85253f8
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions rocprim/include/rocprim/device/detail/device_merge.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -68,24 +68,23 @@ void partition_kernel_impl(IndexIterator indices,
const unsigned int spacing,
BinaryFunction compare_function)
{
const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
const unsigned int flat_block_id = ::rocprim::detail::block_id<0>();
const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
const unsigned int flat_block_id = ::rocprim::detail::block_id<0>();
const unsigned int flat_block_size = ::rocprim::detail::block_size<0>();
const unsigned int input_size = input1_size + input2_size;
const unsigned int id = flat_block_id * flat_block_size + flat_id;
const unsigned int partition_id = id * spacing;
const unsigned int partitions = (input_size + spacing - 1) / spacing;

unsigned int id = flat_block_id * flat_block_size + flat_id;
if(id > partitions)
{
return;
}

unsigned int partition_id = id * spacing;
size_t diag = min(static_cast<size_t>(partition_id), input1_size + input2_size);

unsigned int begin =
merge_path(
keys_input1,
keys_input2,
input1_size,
input2_size,
diag,
compare_function
);
unsigned int begin
= merge_path(keys_input1, keys_input2, input1_size, input2_size, diag, compare_function);

indices[id] = begin;
}
Expand Down Expand Up @@ -310,8 +309,10 @@ void merge_kernel_impl(IndexIterator indices,
const unsigned int valid_in_last_block = count - block_offset;
const bool is_incomplete_block = valid_in_last_block < items_per_block;

const unsigned int p1 = indices[flat_block_id];
const unsigned int p2 = indices[flat_block_id + 1];
const unsigned int partitions = (count + items_per_block - 1) / items_per_block;

const unsigned int p1 = indices[rocprim::min(flat_block_id, partitions)];
const unsigned int p2 = indices[rocprim::min(flat_block_id + 1, partitions)];

range_t range =
compute_range(
Expand Down

0 comments on commit 85253f8

Please sign in to comment.