Skip to content

Commit

Permalink
[luci/pass] Remove duplicate code
Browse files Browse the repository at this point in the history
This removes duplicate codes in ConvertNCHWToNHWC.

ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong <[email protected]>
  • Loading branch information
jinevening committed Jul 26, 2024
1 parent cdd5a91 commit 131abd7
Showing 1 changed file with 2 additions and 42 deletions.
44 changes: 2 additions & 42 deletions compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,46 +170,6 @@ bool check_4d_transpose(loco::Node *node, const std::vector<int32_t> &indices)
return true;
}

luci::CircleTranspose *create_4d_transpose(luci::CircleNode *node,
const std::vector<int32_t> &indices)
{
assert(indices.size() == 4);

auto name = node->name();
assert(name.length() > 0);

auto perm = node->graph()->nodes()->create<luci::CircleConst>();
perm->dtype(loco::DataType::S32);
perm->size<loco::DataType::S32>(4);
perm->rank(1);
perm->dim(0) = 4;
for (uint32_t i = 0; i < 4; i++)
perm->at<loco::DataType::S32>(i) = indices[i];
perm->shape_status(luci::ShapeStatus::VALID);

auto make_string = [](const std::vector<int32_t> &nums) {
std::string str;
for (auto num : nums)
{
if (str.length() > 0)
str += ".";
str += std::to_string(num);
}
return str;
};

auto str_indices = make_string(indices);

perm->name(name + "/Transpose_" + str_indices + "/perm");

auto trans = node->graph()->nodes()->create<luci::CircleTranspose>();
trans->perm(perm);
trans->name(name + "/Transpose_" + str_indices);
luci::add_origin(trans, luci::get_origin(node));

return trans;
}

luci::CircleTranspose *create_Nd_transpose(luci::CircleNode *node,
const std::vector<int32_t> &indices)
{
Expand Down Expand Up @@ -280,12 +240,12 @@ luci::CircleConst *create_nhwc_axis(luci::CircleConst *axis)

luci::CircleTranspose *create_post_transpose(luci::CircleNode *node)
{
return create_4d_transpose(node, {0, 3, 1, 2});
return create_Nd_transpose(node, {0, 3, 1, 2});
}

luci::CircleTranspose *create_pre_transpose(luci::CircleNode *node)
{
return create_4d_transpose(node, {0, 2, 3, 1});
return create_Nd_transpose(node, {0, 2, 3, 1});
}

bool check_4d_reshape(loco::Node *node, const std::vector<int32_t> indices)
Expand Down

0 comments on commit 131abd7

Please sign in to comment.