Skip to content

Commit

Permalink
[compute/cker] Unit test of RmsNorm added.
Browse files Browse the repository at this point in the history
This commit adds unit test of RmsNorm in cker.

ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected]
  • Loading branch information
seockho-kim committed Sep 27, 2024
1 parent 6d93fdc commit 3336cc4
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 2 deletions.
8 changes: 6 additions & 2 deletions compute/cker/include/cker/operation/RmsNorm.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,12 @@ inline void RmsNorm(const RmsNormParams &params, const Shape &input_shape, const
const int32_t widths = MatchingDim(input_shape, 2, output_shape, 2);
const int32_t channels = MatchingDim(input_shape, 3, output_shape, 3);

UNUSED_RELEASE(gamma_shape);
UNUSED_RELEASE(beta_shape);
if (gamma_shape.DimensionsCount() != 1 ||
gamma_shape.Dims(0) != input_shape.Dims(input_shape.DimensionsCount() - 1))
throw std::runtime_error("cker::RmsNorm: Unmatched gamma shape");
if (beta_shape.DimensionsCount() != 1 ||
beta_shape.Dims(0) != input_shape.Dims(input_shape.DimensionsCount() - 1))
throw std::runtime_error("cker::RmsNorm: Unmatched beta shape");

for (int32_t batch = 0; batch < batches; batch++)
{
Expand Down
100 changes: 100 additions & 0 deletions compute/cker/src/RmsNorm.test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cker/operation/RmsNorm.h>

#include <gtest/gtest.h>
#include <vector>

TEST(CKer_Operation, RmsNorm)
{
// Simple
{
std::vector<float> input = {0, 1, 2, 3};
nnfw::cker::Shape input_shape{1, 2, 2, 1};

std::vector<float> expected_output = {0, 1, 1, 1};
std::vector<float> output(expected_output.size());
nnfw::cker::Shape output_shape{1, 2, 2, 1};

std::vector<float> gamma = {1};
nnfw::cker::Shape gamma_shape{1};

std::vector<float> beta = {0};
nnfw::cker::Shape beta_shape{1};

nnfw::cker::RmsNormParams param;
param.epsilon = 0.00001f;

nnfw::cker::RmsNorm(param, input_shape, input.data(), gamma_shape, gamma.data(), beta_shape,
beta.data(), output_shape, output.data());

for (size_t i = 0; i < expected_output.size(); ++i)
EXPECT_NEAR(output[i], expected_output[i], 1e-5f);
}

// Default gamma and beta
{
std::vector<float> input = {0, 1, 2, 3, 4, 5, 6, 7};
nnfw::cker::Shape input_shape{1, 2, 2, 2};

std::vector<float> expected_output = {0, 1.412802, 0.784404, 1.176606,
0.883431, 1.104288, 0.920347, 1.073738};
std::vector<float> output(expected_output.size());
nnfw::cker::Shape output_shape{1, 2, 2, 2};

std::vector<float> gamma = {1, 1};
nnfw::cker::Shape gamma_shape{2};

std::vector<float> beta = {0, 0};
nnfw::cker::Shape beta_shape{2};

nnfw::cker::RmsNormParams param;
param.epsilon = 0.001f;

nnfw::cker::RmsNorm(param, input_shape, input.data(), gamma_shape, gamma.data(), beta_shape,
beta.data(), output_shape, output.data());

for (size_t i = 0; i < expected_output.size(); ++i)
EXPECT_NEAR(output[i], expected_output[i], 1e-5f);
}
}

TEST(CKer_Operation, neg_RmsNormWrongGammaDims)
{
{
std::vector<float> input = {0, 1, 2, 3, 4, 5, 6, 7};
nnfw::cker::Shape input_shape{1, 2, 2, 2};

std::vector<float> expected_output = {0, 1.412802, 0.784404, 1.176606,
0.883431, 1.104288, 0.920347, 1.073738};
std::vector<float> output(expected_output.size());
nnfw::cker::Shape output_shape{1, 2, 2, 2};

std::vector<float> gamma = {1};
nnfw::cker::Shape gamma_shape{1};

std::vector<float> beta = {0, 0};
nnfw::cker::Shape beta_shape{2};

nnfw::cker::RmsNormParams param;
param.epsilon = 0.001f;

EXPECT_ANY_THROW(nnfw::cker::RmsNorm(param, input_shape, input.data(), gamma_shape,
gamma.data(), beta_shape, beta.data(), output_shape,
output.data()));
}
}

0 comments on commit 3336cc4

Please sign in to comment.