Skip to content

Commit

Permalink
[compute/cker] Set default gamma,beta in RmsNorm
Browse files Browse the repository at this point in the history
- If no gamma and beta are provided, use default gamma(1.0) and beta(0.0)

ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected]
  • Loading branch information
seockho-kim committed Sep 30, 2024
1 parent 3336cc4 commit 05593ea
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions compute/cker/include/cker/operation/RmsNorm.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ inline void RmsNorm(const RmsNormParams &params, const Shape &input_shape, const
double rms = std::sqrt((square_sum / channels) + params.epsilon);
for (int32_t channel = 0; channel < channels; channel++)
{
double gamma = gamma_data[channel];
double beta = beta_data[channel];
double gamma = (gamma_data ? gamma_data[channel] : 1.0);
double beta = (beta_data ? beta_data[channel] : 0.0);
output_data[Offset(output_shape, batch, height, width, channel)] =
(gamma * (input_data[Offset(input_shape, batch, height, width, channel)] / rms) + beta);
}
Expand Down

0 comments on commit 05593ea

Please sign in to comment.