[PHI] Fix performance issue in bilinear interpolation's backward kernel. #68541

Open · wants to merge 3 commits into base: develop
150 changes: 112 additions & 38 deletions paddle/phi/kernels/gpu/interpolate_grad_kernel.cu
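The PR carries no written description, so the summary below is a reading of the diff itself. Two things change: the helper PreCalculatorForLinearInterpInputIndex now derives both interpolation weights from a single floorf instead of an int truncation followed by a float round trip, and KeBilinearInterpNCHWBw gains a gather-based fast path for strong upscaling (ratio_w < 0.5) in which each thread owns one input pixel and privately accumulates every output-gradient contribution to it, removing the four phi::CudaAtomicAdd calls per output pixel issued by the original scatter loop. That original loop is retained as the fallback branch.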
@@ -35,12 +35,13 @@ __forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
     T* lambda2,
     T src_x,
     const int in_img_x) {
-  src_x = (src_x > static_cast<T>(0)) ? src_x : static_cast<T>(0);
-  *in_img_idx = static_cast<int>(src_x);
-  *x_id = (*in_img_idx < in_img_x - 1) ? 1 : 0;
-  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
-  *lambda1 = static_cast<T>(static_cast<MT>(src_x) - *in_img_idx);
-  *lambda2 = static_cast<T>(1.0) - *lambda1;
+  src_x = max(src_x, static_cast<T>(0));
+  T src_x_floor = floorf(src_x);
+  T frac_part = src_x - src_x_floor;
+  *lambda1 = frac_part;
+  *lambda2 = static_cast<T>(1) - frac_part;
+  *in_img_idx = static_cast<int>(src_x_floor);
+  *x_id = (*in_img_idx < in_img_x - 1);
 }
 
 template <typename T>
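To make the first hunk concrete, here is a minimal host-side sketch (my own, not Paddle code) comparing the removed and the added formulations, with float standing in for the template parameter T and its MPTypeTrait type. For the clamped, non-negative src_x the two agree exactly; the new form just computes the fractional part once with floorf instead of bouncing through an int truncation.

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>

// Removed formulation: truncate to int, subtract to get lambda1.
void OldPath(float src_x, int in_img_x,
             int* idx, int* x_id, float* l1, float* l2) {
  src_x = (src_x > 0.0f) ? src_x : 0.0f;
  *idx = static_cast<int>(src_x);  // truncation == floor for src_x >= 0
  *x_id = (*idx < in_img_x - 1) ? 1 : 0;
  *l1 = src_x - static_cast<float>(*idx);
  *l2 = 1.0f - *l1;
}

// Added formulation: one floorf, fractional part reused for both lambdas.
void NewPath(float src_x, int in_img_x,
             int* idx, int* x_id, float* l1, float* l2) {
  src_x = std::fmax(src_x, 0.0f);
  float src_x_floor = std::floor(src_x);
  float frac_part = src_x - src_x_floor;
  *l1 = frac_part;
  *l2 = 1.0f - frac_part;
  *idx = static_cast<int>(src_x_floor);
  *x_id = (*idx < in_img_x - 1);
}

int main() {
  for (float s : {-0.3f, 0.0f, 0.75f, 3.5f, 9.99f}) {
    int ia, xa, ib, xb;
    float l1a, l2a, l1b, l2b;
    OldPath(s, 10, &ia, &xa, &l1a, &l2a);
    NewPath(s, 10, &ib, &xb, &l1b, &l2b);
    assert(ia == ib && xa == xb && l1a == l1b && l2a == l2b);
  }
  std::printf("old and new weight computations agree\n");
  return 0;
}
```

Presumably the point of the rewrite is that the fractional part stays in floating-point registers, so lambda1 no longer depends on an int-to-float conversion of *in_img_idx.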
@@ -360,42 +361,115 @@ __global__ void KeBilinearInterpNCHWBw(T* in,
     const T* __restrict__ out,
     const float align_type_value) {
   int index = threadIdx.x + blockDim.x * blockIdx.x;
-  int stride = blockDim.x * gridDim.x;
-  int num_out = n * num_channels * out_h * out_w;
-  int num_in = n * num_channels * in_h * in_w;
+  const int stride = blockDim.x * gridDim.x;
+  const int num_out = n * num_channels * out_h * out_w;
+  const int num_in = n * num_channels * in_h * in_w;
   using MT = typename phi::dtype::MPTypeTrait<T>::Type;
 
-  for (; index < num_out; index += stride) {
-    int index_tmp = index;
-    int w2 = index_tmp % out_w;
-    index_tmp /= out_w;
-    int h2 = index_tmp % out_h;
-    int nc = index_tmp / out_h;
-
-    int h1, y_id;
-    MT h1lambda, h0lambda;
-    MT src_y =
-        static_cast<MT>(ratio_h * (h2 + align_type_value) - align_type_value);
-
-    PreCalculatorForLinearInterpInputIndex(
-        &h1, &y_id, &h1lambda, &h0lambda, src_y, in_h);
-    int w1, x_id;
-    MT w1lambda, w0lambda;
-    MT src_x =
-        static_cast<MT>(ratio_w * (w2 + align_type_value) - align_type_value);
-    PreCalculatorForLinearInterpInputIndex(
-        &w1, &x_id, &w1lambda, &w0lambda, src_x, in_w);
-
-    MT d2val = static_cast<MT>(out[index]);
-
-    phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1),
-                       static_cast<T>(h0lambda * w0lambda * d2val));
-    phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1 + x_id),
-                       static_cast<T>(h0lambda * w1lambda * d2val));
-    phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1),
-                       static_cast<T>(h1lambda * w0lambda * d2val));
-    phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1 + x_id),
-                       static_cast<T>(h1lambda * w1lambda * d2val));
-  }
+  // Restrict parallelism (one thread per input pixel) when ratio_w is
+  // below the threshold, to avoid atomic contention overhead.
+  if (ratio_w < 0.5) [[likely]] {  // NOLINT
+    if (index < num_in) {
+      int index_tmp = index;
+      const int w1 = index_tmp % in_w;
+      index_tmp /= in_w;
+      const int h1 = index_tmp % in_h;
+      const int nc = index_tmp / in_h;
+
+      MT d2val_sum = 0.0f;
+
+      // Precompute constants
+      const MT inv_ratio_h = 1.0f / ratio_h;
+      const MT inv_ratio_w = 1.0f / ratio_w;
+
+      // Compute the range of output pixels (h2_min, h2_max) that could
+      // affect input pixel h1
+      const MT h2r_min =
+          (h1 - 1 + align_type_value) * inv_ratio_h - align_type_value;
+      const int h2_min = max(static_cast<int>(ceilf(h2r_min)), 0);
+
+      const MT h2r_max =
+          (h1 + 1 + align_type_value) * inv_ratio_h - align_type_value;
+      const int h2_max = min(static_cast<int>(floorf(h2r_max)), out_h - 1);
+
+      // Compute the range of output pixels (w2_min, w2_max) that could
+      // affect input pixel w1
+      const MT w2r_min =
+          (w1 - 1 + align_type_value) * inv_ratio_w - align_type_value;
+      const int w2_min = max(static_cast<int>(ceilf(w2r_min)), 0);
+
+      const MT w2r_max =
+          (w1 + 1 + align_type_value) * inv_ratio_w - align_type_value;
+      const int w2_max = min(static_cast<int>(floorf(w2r_max)), out_w - 1);
+
+      for (int h2 = h2_min; h2 <= h2_max; ++h2) {
+        const MT src_y = ratio_h * (h2 + align_type_value) - align_type_value;
+        int h1_, y_id;
+        MT h1lambda, h0lambda;
+        PreCalculatorForLinearInterpInputIndex(
+            &h1_, &y_id, &h1lambda, &h0lambda, src_y, in_h);
+
+        if (h1 != h1_ && h1 != h1_ + y_id) [[unlikely]] {
+          continue;
+        }
+
+        for (int w2 = w2_min; w2 <= w2_max; ++w2) {
+          int w1_, x_id;
+          const MT src_x = ratio_w * (w2 + align_type_value) - align_type_value;
+          MT w1lambda, w0lambda;
+          PreCalculatorForLinearInterpInputIndex(
+              &w1_, &x_id, &w1lambda, &w0lambda, src_x, in_w);
+          if (w1 != w1_ && w1 != w1_ + x_id) [[unlikely]] {
+            continue;
+          }
+
+          const MT grad_output = out[nc * out_h * out_w + h2 * out_w + w2];
+
+          float hlambda = (h1 == h1_) ? h0lambda : 0.0f;
+          hlambda += (h1 == h1_ + y_id) ? h1lambda : 0.0f;
+
+          float wlambda = (w1 == w1_) ? w0lambda : 0.0f;
+          wlambda += (w1 == w1_ + x_id) ? w1lambda : 0.0f;
+
+          d2val_sum += hlambda * wlambda * grad_output;
+        }
+      }
+      in[index] = static_cast<T>(d2val_sum);
+    }
+  } else [[unlikely]] {  // NOLINT
+    for (; index < num_out; index += stride) {
+      int index_tmp = index;
+      int w2 = index_tmp % out_w;
+      index_tmp /= out_w;
+      int h2 = index_tmp % out_h;
+      int nc = index_tmp / out_h;
+
+      int h1, y_id;
+      MT h1lambda, h0lambda;
+      MT src_y =
+          static_cast<MT>(ratio_h * (h2 + align_type_value) - align_type_value);
+
+      PreCalculatorForLinearInterpInputIndex(
+          &h1, &y_id, &h1lambda, &h0lambda, src_y, in_h);
+      int w1, x_id;
+      MT w1lambda, w0lambda;
+      MT src_x =
+          static_cast<MT>(ratio_w * (w2 + align_type_value) - align_type_value);
+      PreCalculatorForLinearInterpInputIndex(
+          &w1, &x_id, &w1lambda, &w0lambda, src_x, in_w);
+
+      MT d2val = static_cast<MT>(out[index]);
+
+      phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1),
+                         static_cast<T>(h0lambda * w0lambda * d2val));
+      phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1 + x_id),
+                         static_cast<T>(h0lambda * w1lambda * d2val));
+      phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1),
+                         static_cast<T>(h1lambda * w0lambda * d2val));
+      phi::CudaAtomicAdd(
+          in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1 + x_id),
+          static_cast<T>(h1lambda * w1lambda * d2val));
+    }
+  }
 }
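The ceilf/floorf bounds in the fast path follow from inverting the source-coordinate mapping. Writing $a$ for align_type_value and $r_w$ for ratio_w, output column $w_2$ can contribute to input column $w_1$ only when its source coordinate lands strictly within one pixel of $w_1$:

$$
w_1 - 1 \;<\; r_w\,(w_2 + a) - a \;<\; w_1 + 1
\quad\Longleftrightarrow\quad
\frac{w_1 - 1 + a}{r_w} - a \;<\; w_2 \;<\; \frac{w_1 + 1 + a}{r_w} - a,
$$

$$
w_2^{\min} = \max\!\left(\left\lceil \frac{w_1 - 1 + a}{r_w} - a \right\rceil,\ 0\right),
\qquad
w_2^{\max} = \min\!\left(\left\lfloor \frac{w_1 + 1 + a}{r_w} - a \right\rfloor,\ \mathrm{out\_w} - 1\right),
$$

and analogously for rows with $r_h$. Each window spans roughly $2/r_w$ output columns, i.e. at least four once $r_w < 0.5$. Note the bounds only have to over-approximate the true support: the `h1 != h1_` and `w1 != w1_` guards inside the loops re-check exact membership before accumulating.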

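For readers who want to see the scatter-to-gather transformation without the CUDA plumbing, here is a single-threaded C++ sketch of the 1-D (linear) analogue. It is an illustration under my own naming, not Paddle code: the plain += in the scatter version stands in for phi::CudaAtomicAdd, and the test sizes are arbitrary. Both functions produce the same input gradient, modulo last-bit float effects where the two taps collapse onto the boundary pixel; the gather version simply has no conflicting writes to serialize.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Scatter (old kernel shape): every output element adds its two weighted
// contributions into the input gradient; on the GPU these are atomicAdds.
void LinearBackwardScatter(const std::vector<float>& gout,
                           std::vector<float>* gin, float ratio, float align) {
  const int in_w = static_cast<int>(gin->size());
  for (int w2 = 0; w2 < static_cast<int>(gout.size()); ++w2) {
    float src = std::max(ratio * (w2 + align) - align, 0.0f);
    int w1 = static_cast<int>(src);  // floor, since src >= 0
    int x_id = (w1 < in_w - 1) ? 1 : 0;
    float lambda1 = src - static_cast<float>(w1);
    (*gin)[w1] += (1.0f - lambda1) * gout[w2];  // atomic on the GPU
    (*gin)[w1 + x_id] += lambda1 * gout[w2];    // atomic on the GPU
  }
}

// Gather (new fast-path shape): every input element scans the small window
// of output elements whose two-tap footprint can reach it and accumulates
// privately, so no atomics are needed.
void LinearBackwardGather(const std::vector<float>& gout,
                          std::vector<float>* gin, float ratio, float align) {
  const int in_w = static_cast<int>(gin->size());
  const int out_w = static_cast<int>(gout.size());
  for (int w1 = 0; w1 < in_w; ++w1) {
    // Invert src = ratio * (w2 + align) - align over src in (w1 - 1, w1 + 1).
    const int w2_min = std::max(
        static_cast<int>(std::ceil((w1 - 1 + align) / ratio - align)), 0);
    const int w2_max = std::min(
        static_cast<int>(std::floor((w1 + 1 + align) / ratio - align)),
        out_w - 1);
    float sum = 0.0f;
    for (int w2 = w2_min; w2 <= w2_max; ++w2) {
      float src = std::max(ratio * (w2 + align) - align, 0.0f);
      int lo = static_cast<int>(src);
      int x_id = (lo < in_w - 1) ? 1 : 0;
      float lambda1 = src - static_cast<float>(lo);
      float weight = (w1 == lo) ? (1.0f - lambda1) : 0.0f;
      weight += (w1 == lo + x_id) ? lambda1 : 0.0f;
      sum += weight * gout[w2];
    }
    (*gin)[w1] = sum;  // plain store, no atomic
  }
}

int main() {
  const int in_w = 5, out_w = 16;                        // ~3.2x upscaling
  const float ratio = static_cast<float>(in_w) / out_w;  // ratio_w < 0.5
  const float align = 0.5f;                              // align_corners=false
  std::vector<float> gout(out_w);
  for (int i = 0; i < out_w; ++i) gout[i] = 0.1f * static_cast<float>(i);
  std::vector<float> scat(in_w, 0.0f), gath(in_w, 0.0f);
  LinearBackwardScatter(gout, &scat, ratio, align);
  LinearBackwardGather(gout, &gath, ratio, align);
  for (int i = 0; i < in_w; ++i)
    std::printf("w1=%d  scatter=%.6f  gather=%.6f\n", i, scat[i], gath[i]);
  return 0;
}
```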