From d7c66185c805ea9b367b7e3928d983b6fdcbf6bb Mon Sep 17 00:00:00 2001 From: sseung Date: Wed, 2 Oct 2024 20:09:25 +0900 Subject: [PATCH 1/5] [luci/pass] Add TaggedShapeAnalyzer::init() in RmUnnTransNetPass This PR adds TaggedShapeAnalyzer::init() function. The init() function explicitly checks some conditions that have be met for the analyzer. ONE-DCO-1.0-Signed-off-by: seunghui youn --- .../src/RemoveUnnecessaryTransposeNetPass.cpp | 154 +++++++++++++----- 1 file changed, 115 insertions(+), 39 deletions(-) diff --git a/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp index f07d4fa18eb..5a607309a46 100644 --- a/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp +++ b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp @@ -24,6 +24,38 @@ namespace { +#define RET_FALSE_UNLESS(condition) \ + if (not(condition)) \ + return false; + +std::vector extract_shape(const luci::CircleNode *node) +{ + std::vector shape; + auto rank = node->rank(); + for (auto i = 0u; i < rank; ++i) + { + shape.push_back(static_cast(node->dim(i).value())); + } + return shape; +}; + +template +std::vector extract_const(const luci::CircleConst *const_node) +{ + static_assert(DTYPE == loco::DataType::S32 || DTYPE == loco::DataType::S16 || + DTYPE == loco::DataType::S8, + "unsupported data type"); + + std::vector values; + auto size = const_node->size(); + for (auto i = 0u; i < size; ++i) + { + auto v = const_node->at(i); + values.push_back(static_cast(v)); + } + return values; +}; + struct TagDim final { int32_t value; @@ -36,8 +68,10 @@ class TaggedShapeAnalyzer final { public: template - bool can_remove_transposes(const luci::CircleTranspose *f_tr, const luci::CircleReshape *m_rs, - const luci::CircleTranspose *b_tr); + bool init(const luci::CircleTranspose *, const luci::CircleReshape *, + const luci::CircleTranspose *); + + template bool can_remove_transposes(); private: void init_shape_with_tag(const luci::CircleNode *); @@ -48,6 +82,17 @@ class TaggedShapeAnalyzer final bool verify_tag() const; +private: + const luci::CircleNode *_in = nullptr; + const luci::CircleTranspose *_front_transpose = nullptr; + const luci::CircleReshape *_mid_reshape = nullptr; + const luci::CircleTranspose *_back_transpose = nullptr; + + std::vector _in_shape; + std::vector _front_perm; + std::vector _mid_shape; + std::vector _back_perm; + const uint8_t START_TAG = 0; TagShape _shape; }; @@ -215,9 +260,65 @@ bool TaggedShapeAnalyzer::verify_tag() const return true; } +/** + * @brief Initialize the class member and checks under conditions + * + * Condtiions that have to be met for analyzer + * c1: input rank >= output rank + * c2: The 'perm' of tranpose should be a CircleConst* type + * c3: The shapes of the given nodes should be all known + * + * @return True, if all conditions are satisfied and class members are initialized successfully + * False, otherwise + */ +template +bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose, + const luci::CircleReshape *mid_reshape, + const luci::CircleTranspose *back_transpose) +{ + _in = dynamic_cast(front_transpose->a()); + _front_transpose = front_transpose; + _mid_reshape = mid_reshape; + _back_transpose = back_transpose; + + // check c1 + RET_FALSE_UNLESS(_in->rank() >= _back_transpose->rank()); + + const auto front_perm = dynamic_cast(_front_transpose->perm()); + const auto back_perm = dynamic_cast(_back_transpose->perm()); + + // check c2 + RET_FALSE_UNLESS(front_perm != nullptr); + RET_FALSE_UNLESS(front_perm->dtype() == DType); + RET_FALSE_UNLESS(back_perm != nullptr); + RET_FALSE_UNLESS(back_perm->dtype() == DType); + + _in_shape = extract_shape(_in); + _front_perm = extract_const(front_perm); + _mid_shape = extract_shape(_mid_reshape); + _back_perm = extract_const(back_perm); + + auto all_known = [](const std::vector &v) -> bool { + for (auto i : v) + if (i <= 0) + return false; + return true; + }; + + // check c3 + RET_FALSE_UNLESS(all_known(_in_shape)); + RET_FALSE_UNLESS(all_known(extract_shape(_front_transpose))); + RET_FALSE_UNLESS(all_known(_mid_shape)); + RET_FALSE_UNLESS(all_known(extract_shape(_back_transpose))); + + return true; +} + /** * @brief check 'Transpose-Reshape-Transpose' can be replaced by one 'Reshape'. * + * @warning '@init' have to be called first + * * @example * Let's explain how analyzer check Transpose-Reshape-Transpose pattern with an exact example. * @@ -255,24 +356,21 @@ bool TaggedShapeAnalyzer::verify_tag() const * Transpose has no effect in final shape, which they can be removed as * unnecessary Ops. */ -template -bool TaggedShapeAnalyzer::can_remove_transposes(const luci::CircleTranspose *f_tr, - const luci::CircleReshape *m_rs, - const luci::CircleTranspose *b_tr) +template bool TaggedShapeAnalyzer::can_remove_transposes() { - assert(loco::must_cast(f_tr->perm())->dtype() == DType); - assert(loco::must_cast(b_tr->perm())->dtype() == DType); - - const luci::CircleNode *in_tensor = loco::must_cast(f_tr->a()); + // TODO: Update methods to use std::vector intead of CircleNode + // For example, + // init_shape_with_tag(_in_shape); + // analyze_transpose(_fornt_perm); - init_shape_with_tag(in_tensor); + init_shape_with_tag(_in); - analyze_transpose(f_tr); + analyze_transpose(_front_transpose); - if (not analyze_reshape(m_rs)) + if (not analyze_reshape(_mid_reshape)) return false; - analyze_transpose(b_tr); + analyze_transpose(_back_transpose); if (not verify_tag()) return false; @@ -339,34 +437,12 @@ bool remove_unnecessary_transpose(luci::CircleTranspose *node) return false; } - // check perm and shape are CircleConst node and its' datatype is S32 - const auto back_perm = dynamic_cast(back_transpose->perm()); - { - if (back_perm == nullptr) - return false; - - if (back_perm->dtype() != loco::DataType::S32) - return false; - } - const auto front_perm = dynamic_cast(front_transpose->perm()); - { - if (front_perm == nullptr) - return false; - - if (front_perm->dtype() != loco::DataType::S32) - return false; - } + TaggedShapeAnalyzer analyzer; - // for now, handle only rank reduction equal (not expansion) cases - const auto output_rank = back_transpose->rank(); - const auto input_rank = front_transpose->rank(); - if (input_rank < output_rank) + if (not analyzer.init(front_transpose, mid_reshape, back_transpose)) return false; - // analyze pattern to check this pass is applicable - TaggedShapeAnalyzer analyzer; - if (not analyzer.can_remove_transposes(front_transpose, mid_reshape, - back_transpose)) + if (not analyzer.can_remove_transposes()) return false; // repalce with new_node From 038e5a7eb2e3906e6e52454b4a8565008482b3f9 Mon Sep 17 00:00:00 2001 From: SeungHui Youn <61981457+zetwhite@users.noreply.github.com> Date: Thu, 3 Oct 2024 11:12:54 +0900 Subject: [PATCH 2/5] trim comments --- compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp index 5a607309a46..c6d444ab5bb 100644 --- a/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp +++ b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp @@ -261,7 +261,7 @@ bool TaggedShapeAnalyzer::verify_tag() const } /** - * @brief Initialize the class member and checks under conditions + * @brief Initialize the class members and check under conditions * * Condtiions that have to be met for analyzer * c1: input rank >= output rank @@ -358,7 +358,7 @@ bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose, */ template bool TaggedShapeAnalyzer::can_remove_transposes() { - // TODO: Update methods to use std::vector intead of CircleNode + // TODO: Update methods to use std::vector intead of CircleNode ptr // For example, // init_shape_with_tag(_in_shape); // analyze_transpose(_fornt_perm); From 706cbf02eceef5929a0fa6b300582bc125c36842 Mon Sep 17 00:00:00 2001 From: sseung Date: Fri, 4 Oct 2024 13:05:13 +0900 Subject: [PATCH 3/5] remove template in extract_const --- .../src/RemoveUnnecessaryTransposeNetPass.cpp | 85 +++++++++++-------- 1 file changed, 50 insertions(+), 35 deletions(-) diff --git a/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp index c6d444ab5bb..c1b09f437e8 100644 --- a/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp +++ b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp @@ -28,32 +28,50 @@ namespace if (not(condition)) \ return false; -std::vector extract_shape(const luci::CircleNode *node) +bool extract_shape(const luci::CircleNode *node, std::vector &shape) { - std::vector shape; + uint32_t max_i32 = static_cast(std::numeric_limits::max()); + auto rank = node->rank(); for (auto i = 0u; i < rank; ++i) { - shape.push_back(static_cast(node->dim(i).value())); + uint32_t v = node->dim(i).value(); + RET_FALSE_UNLESS(v <= max_i32) + shape.push_back(static_cast(v)); } - return shape; + return true; }; -template -std::vector extract_const(const luci::CircleConst *const_node) +bool extract_const(const luci::CircleConst *const_node, std::vector &values) { - static_assert(DTYPE == loco::DataType::S32 || DTYPE == loco::DataType::S16 || - DTYPE == loco::DataType::S8, - "unsupported data type"); + auto dtype = const_node->dtype(); - std::vector values; - auto size = const_node->size(); - for (auto i = 0u; i < size; ++i) + if (dtype == loco::DataType::S32) { - auto v = const_node->at(i); - values.push_back(static_cast(v)); + auto size = const_node->size(); + for (auto i = 0u; i < size; ++i) + { + int32_t v = const_node->at(i); + values.push_back(v); + } } - return values; + else if (dtype == loco::DataType::S64) + { + int64_t max_i32 = static_cast(std::numeric_limits::max()); + int64_t min_i32 = static_cast(std::numeric_limits::lowest()); + + auto size = const_node->size(); + for (auto i = 0u; i < size; ++i) + { + int64_t v = const_node->at(i); + RET_FALSE_UNLESS(min_i32 <= v && v <= max_i32); + values.push_back(static_cast(v)); + } + } + else + return false; + + return true; }; struct TagDim final @@ -67,7 +85,6 @@ using TagShape = std::vector; class TaggedShapeAnalyzer final { public: - template bool init(const luci::CircleTranspose *, const luci::CircleReshape *, const luci::CircleTranspose *); @@ -88,10 +105,10 @@ class TaggedShapeAnalyzer final const luci::CircleReshape *_mid_reshape = nullptr; const luci::CircleTranspose *_back_transpose = nullptr; - std::vector _in_shape; - std::vector _front_perm; - std::vector _mid_shape; - std::vector _back_perm; + std::vector _in_shape_v; + std::vector _front_perm_v; + std::vector _mid_shape_v; + std::vector _back_perm_v; const uint8_t START_TAG = 0; TagShape _shape; @@ -266,12 +283,11 @@ bool TaggedShapeAnalyzer::verify_tag() const * Condtiions that have to be met for analyzer * c1: input rank >= output rank * c2: The 'perm' of tranpose should be a CircleConst* type - * c3: The shapes of the given nodes should be all known + * c3: The shapes of input node and reshape node should be known * * @return True, if all conditions are satisfied and class members are initialized successfully * False, otherwise */ -template bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose, const luci::CircleReshape *mid_reshape, const luci::CircleTranspose *back_transpose) @@ -289,14 +305,12 @@ bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose, // check c2 RET_FALSE_UNLESS(front_perm != nullptr); - RET_FALSE_UNLESS(front_perm->dtype() == DType); RET_FALSE_UNLESS(back_perm != nullptr); - RET_FALSE_UNLESS(back_perm->dtype() == DType); - _in_shape = extract_shape(_in); - _front_perm = extract_const(front_perm); - _mid_shape = extract_shape(_mid_reshape); - _back_perm = extract_const(back_perm); + RET_FALSE_UNLESS(extract_shape(_in, _in_shape_v)); + RET_FALSE_UNLESS(extract_const(front_perm, _front_perm_v)); + RET_FALSE_UNLESS(extract_shape(_mid_reshape, _mid_shape_v)); + RET_FALSE_UNLESS(extract_const(back_perm, _back_perm_v)); auto all_known = [](const std::vector &v) -> bool { for (auto i : v) @@ -306,10 +320,8 @@ bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose, }; // check c3 - RET_FALSE_UNLESS(all_known(_in_shape)); - RET_FALSE_UNLESS(all_known(extract_shape(_front_transpose))); - RET_FALSE_UNLESS(all_known(_mid_shape)); - RET_FALSE_UNLESS(all_known(extract_shape(_back_transpose))); + RET_FALSE_UNLESS(all_known(_in_shape_v)); + RET_FALSE_UNLESS(all_known(_mid_shape_v)); return true; } @@ -358,10 +370,13 @@ bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose, */ template bool TaggedShapeAnalyzer::can_remove_transposes() { + assert(_in != nullptr || _front_transpose != nullptr || _mid_reshape != nullptr || + _back_transpose != nullptr); + // TODO: Update methods to use std::vector intead of CircleNode ptr // For example, - // init_shape_with_tag(_in_shape); - // analyze_transpose(_fornt_perm); + // init_shape_with_tag(_in_shape_v); + // analyze_transpose(_fornt_perm_v); init_shape_with_tag(_in); @@ -439,7 +454,7 @@ bool remove_unnecessary_transpose(luci::CircleTranspose *node) TaggedShapeAnalyzer analyzer; - if (not analyzer.init(front_transpose, mid_reshape, back_transpose)) + if (not analyzer.init(front_transpose, mid_reshape, back_transpose)) return false; if (not analyzer.can_remove_transposes()) From a46729d59eb22f59535e8b4bfeb8655be2279444 Mon Sep 17 00:00:00 2001 From: sseung Date: Fri, 4 Oct 2024 13:07:58 +0900 Subject: [PATCH 4/5] trim --- .../src/RemoveUnnecessaryTransposeNetPass.cpp | 38 +++++++++---------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp index c1b09f437e8..601a4d08a7d 100644 --- a/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp +++ b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp @@ -24,8 +24,8 @@ namespace { -#define RET_FALSE_UNLESS(condition) \ - if (not(condition)) \ +#define CHECK_OR_FALSE(condition) \ + if (not(condition)) \ return false; bool extract_shape(const luci::CircleNode *node, std::vector &shape) @@ -36,7 +36,7 @@ bool extract_shape(const luci::CircleNode *node, std::vector &shape) for (auto i = 0u; i < rank; ++i) { uint32_t v = node->dim(i).value(); - RET_FALSE_UNLESS(v <= max_i32) + CHECK_OR_FALSE(v <= max_i32) shape.push_back(static_cast(v)); } return true; @@ -64,7 +64,7 @@ bool extract_const(const luci::CircleConst *const_node, std::vector &va for (auto i = 0u; i < size; ++i) { int64_t v = const_node->at(i); - RET_FALSE_UNLESS(min_i32 <= v && v <= max_i32); + CHECK_OR_FALSE(min_i32 <= v && v <= max_i32); values.push_back(static_cast(v)); } } @@ -283,7 +283,7 @@ bool TaggedShapeAnalyzer::verify_tag() const * Condtiions that have to be met for analyzer * c1: input rank >= output rank * c2: The 'perm' of tranpose should be a CircleConst* type - * c3: The shapes of input node and reshape node should be known + * c3: The input shape and the reshape node's shape should be known * * @return True, if all conditions are satisfied and class members are initialized successfully * False, otherwise @@ -292,25 +292,25 @@ bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose, const luci::CircleReshape *mid_reshape, const luci::CircleTranspose *back_transpose) { - _in = dynamic_cast(front_transpose->a()); + _in = loco::must_cast(front_transpose->a()); _front_transpose = front_transpose; _mid_reshape = mid_reshape; _back_transpose = back_transpose; // check c1 - RET_FALSE_UNLESS(_in->rank() >= _back_transpose->rank()); + CHECK_OR_FALSE(_in->rank() >= _back_transpose->rank()); const auto front_perm = dynamic_cast(_front_transpose->perm()); const auto back_perm = dynamic_cast(_back_transpose->perm()); // check c2 - RET_FALSE_UNLESS(front_perm != nullptr); - RET_FALSE_UNLESS(back_perm != nullptr); + CHECK_OR_FALSE(front_perm != nullptr); + CHECK_OR_FALSE(back_perm != nullptr); - RET_FALSE_UNLESS(extract_shape(_in, _in_shape_v)); - RET_FALSE_UNLESS(extract_const(front_perm, _front_perm_v)); - RET_FALSE_UNLESS(extract_shape(_mid_reshape, _mid_shape_v)); - RET_FALSE_UNLESS(extract_const(back_perm, _back_perm_v)); + CHECK_OR_FALSE(extract_shape(_in, _in_shape_v)); + CHECK_OR_FALSE(extract_const(front_perm, _front_perm_v)); + CHECK_OR_FALSE(extract_shape(_mid_reshape, _mid_shape_v)); + CHECK_OR_FALSE(extract_const(back_perm, _back_perm_v)); auto all_known = [](const std::vector &v) -> bool { for (auto i : v) @@ -320,8 +320,8 @@ bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose, }; // check c3 - RET_FALSE_UNLESS(all_known(_in_shape_v)); - RET_FALSE_UNLESS(all_known(_mid_shape_v)); + CHECK_OR_FALSE(all_known(_in_shape_v)); + CHECK_OR_FALSE(all_known(_mid_shape_v)); return true; } @@ -370,14 +370,10 @@ bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose, */ template bool TaggedShapeAnalyzer::can_remove_transposes() { - assert(_in != nullptr || _front_transpose != nullptr || _mid_reshape != nullptr || + assert(_in != nullptr && _front_transpose != nullptr && _mid_reshape != nullptr && _back_transpose != nullptr); - // TODO: Update methods to use std::vector intead of CircleNode ptr - // For example, - // init_shape_with_tag(_in_shape_v); - // analyze_transpose(_fornt_perm_v); - + // TODO: Update under methods to use std::vector intead of CircleNode* init_shape_with_tag(_in); analyze_transpose(_front_transpose); From 1d207d8a2f255a7c6eea17d6ddc4afc31b2dbe44 Mon Sep 17 00:00:00 2001 From: sseung Date: Fri, 4 Oct 2024 13:52:29 +0900 Subject: [PATCH 5/5] include limits header --- compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp index 601a4d08a7d..006ede6fd1f 100644 --- a/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp +++ b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp @@ -19,6 +19,7 @@ #include #include +#include #include namespace