Skip to content

Commit

Permalink
[luci/partition] Support RmsNorm operation
Browse files Browse the repository at this point in the history
This commit supports RmsNorm operation in luci partition.

ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected]
  • Loading branch information
seockho-kim committed Sep 13, 2024
1 parent d4d3bf9 commit f43c9cd
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/luci/partition/include/luci/ConnectNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ class ConnectNode final : public luci::CircleNodeVisitor<void>
void visit(const luci::CircleBCQGather *) final;
void visit(const luci::CircleGRU *) final;
void visit(const luci::CircleInstanceNorm *) final;
void visit(const luci::CircleRmsNorm *) final;

// NOTE CircleInput and CircleOutput are not handled here as these need
// link with graph I/O
Expand Down
42 changes: 42 additions & 0 deletions compiler/luci/partition/src/Nodes/CircleRmsNorm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/ConnectNode.h"

namespace
{

void connect(luci::ConnectNode *cn, const luci::CircleRmsNorm *node)
{
auto *cloned = loco::must_cast<luci::CircleRmsNorm *>(cn->find_clone(node));

luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
luci::CircleNode *gamma = loco::must_cast<luci::CircleNode *>(node->gamma());
luci::CircleNode *beta = loco::must_cast<luci::CircleNode *>(node->beta());

cloned->input(cn->find_clone(input));
cloned->gamma(cn->find_clone(gamma));
cloned->beta(cn->find_clone(beta));
}

} // namespace

namespace luci
{

void ConnectNode::visit(const luci::CircleRmsNorm *node) { connect(this, node); }

} // namespace luci
97 changes: 97 additions & 0 deletions compiler/luci/partition/src/Nodes/CircleRmsNorm.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/ConnectNode.h"

#include "ConnectNode.test.h"

#include <luci/Service/CircleNodeClone.h>

#include <gtest/gtest.h>

namespace
{

using namespace luci::test;

class NodeGraphlet : public NodeGraphletT<luci::CircleRmsNorm>
{
public:
NodeGraphlet() = default;

public:
void init(loco::Graph *g) override { NodeGraphletT<luci::CircleRmsNorm>::init(g); }
};

class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
{
public:
TestNodeGraph() = default;

public:
void init(const ShapeU32 shape)
{
TestIsOGraph<3>::init({shape, shape, shape}, shape);
NodeGraphlet::init(g());

node()->input(input(0));
node()->gamma(input(1));
node()->beta(input(2));

output()->from(node());
}
};

} // namespace

TEST(ConnectNodeTest, connect_RmsNorm)
{
TestNodeGraph tng;
tng.init({2, 3});

ConnectionTestHelper cth;
cth.prepare_inputs(&tng);

auto *node = tng.node();
ASSERT_NO_THROW(loco::must_cast<luci::CircleRmsNorm *>(node));

auto *clone = luci::clone_node(node, cth.graph_clone());
ASSERT_NO_THROW(loco::must_cast<luci::CircleRmsNorm *>(clone));

cth.clone_connect(node, clone);

ASSERT_EQ(3, clone->arity());
ASSERT_EQ(cth.inputs(0), clone->arg(0));
ASSERT_EQ(cth.inputs(1), clone->arg(1));
ASSERT_EQ(cth.inputs(2), clone->arg(2));
}

TEST(ConnectNodeTest, connect_RmsNorm_NEG)
{
TestNodeGraph tng;
tng.init({2, 3});

ConnectionTestHelper cth;
cth.prepare_inputs_miss(&tng);

auto *node = tng.node();
ASSERT_NO_THROW(loco::must_cast<luci::CircleRmsNorm *>(node));

auto *clone = luci::clone_node(node, cth.graph_clone());
ASSERT_NO_THROW(loco::must_cast<luci::CircleRmsNorm *>(clone));

EXPECT_ANY_THROW(cth.clone_connect(node, clone));
}

0 comments on commit f43c9cd

Please sign in to comment.