diff --git a/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp index f07d4fa18eb..006ede6fd1f 100644 --- a/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp +++ b/compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp @@ -19,11 +19,62 @@ #include #include +#include #include namespace { +#define CHECK_OR_FALSE(condition) \ + if (not(condition)) \ + return false; + +bool extract_shape(const luci::CircleNode *node, std::vector &shape) +{ + uint32_t max_i32 = static_cast(std::numeric_limits::max()); + + auto rank = node->rank(); + for (auto i = 0u; i < rank; ++i) + { + uint32_t v = node->dim(i).value(); + CHECK_OR_FALSE(v <= max_i32) + shape.push_back(static_cast(v)); + } + return true; +}; + +bool extract_const(const luci::CircleConst *const_node, std::vector &values) +{ + auto dtype = const_node->dtype(); + + if (dtype == loco::DataType::S32) + { + auto size = const_node->size(); + for (auto i = 0u; i < size; ++i) + { + int32_t v = const_node->at(i); + values.push_back(v); + } + } + else if (dtype == loco::DataType::S64) + { + int64_t max_i32 = static_cast(std::numeric_limits::max()); + int64_t min_i32 = static_cast(std::numeric_limits::lowest()); + + auto size = const_node->size(); + for (auto i = 0u; i < size; ++i) + { + int64_t v = const_node->at(i); + CHECK_OR_FALSE(min_i32 <= v && v <= max_i32); + values.push_back(static_cast(v)); + } + } + else + return false; + + return true; +}; + struct TagDim final { int32_t value; @@ -35,9 +86,10 @@ using TagShape = std::vector; class TaggedShapeAnalyzer final { public: - template - bool can_remove_transposes(const luci::CircleTranspose *f_tr, const luci::CircleReshape *m_rs, - const luci::CircleTranspose *b_tr); + bool init(const luci::CircleTranspose *, const luci::CircleReshape *, + const luci::CircleTranspose *); + + template bool can_remove_transposes(); private: void init_shape_with_tag(const luci::CircleNode *); @@ -48,6 +100,17 @@ class TaggedShapeAnalyzer final bool verify_tag() const; +private: + const luci::CircleNode *_in = nullptr; + const luci::CircleTranspose *_front_transpose = nullptr; + const luci::CircleReshape *_mid_reshape = nullptr; + const luci::CircleTranspose *_back_transpose = nullptr; + + std::vector _in_shape_v; + std::vector _front_perm_v; + std::vector _mid_shape_v; + std::vector _back_perm_v; + const uint8_t START_TAG = 0; TagShape _shape; }; @@ -215,9 +278,60 @@ bool TaggedShapeAnalyzer::verify_tag() const return true; } +/** + * @brief Initialize the class members and check under conditions + * + * Condtiions that have to be met for analyzer + * c1: input rank >= output rank + * c2: The 'perm' of tranpose should be a CircleConst* type + * c3: The input shape and the reshape node's shape should be known + * + * @return True, if all conditions are satisfied and class members are initialized successfully + * False, otherwise + */ +bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose, + const luci::CircleReshape *mid_reshape, + const luci::CircleTranspose *back_transpose) +{ + _in = loco::must_cast(front_transpose->a()); + _front_transpose = front_transpose; + _mid_reshape = mid_reshape; + _back_transpose = back_transpose; + + // check c1 + CHECK_OR_FALSE(_in->rank() >= _back_transpose->rank()); + + const auto front_perm = dynamic_cast(_front_transpose->perm()); + const auto back_perm = dynamic_cast(_back_transpose->perm()); + + // check c2 + CHECK_OR_FALSE(front_perm != nullptr); + CHECK_OR_FALSE(back_perm != nullptr); + + CHECK_OR_FALSE(extract_shape(_in, _in_shape_v)); + CHECK_OR_FALSE(extract_const(front_perm, _front_perm_v)); + CHECK_OR_FALSE(extract_shape(_mid_reshape, _mid_shape_v)); + CHECK_OR_FALSE(extract_const(back_perm, _back_perm_v)); + + auto all_known = [](const std::vector &v) -> bool { + for (auto i : v) + if (i <= 0) + return false; + return true; + }; + + // check c3 + CHECK_OR_FALSE(all_known(_in_shape_v)); + CHECK_OR_FALSE(all_known(_mid_shape_v)); + + return true; +} + /** * @brief check 'Transpose-Reshape-Transpose' can be replaced by one 'Reshape'. * + * @warning '@init' have to be called first + * * @example * Let's explain how analyzer check Transpose-Reshape-Transpose pattern with an exact example. * @@ -255,24 +369,20 @@ bool TaggedShapeAnalyzer::verify_tag() const * Transpose has no effect in final shape, which they can be removed as * unnecessary Ops. */ -template -bool TaggedShapeAnalyzer::can_remove_transposes(const luci::CircleTranspose *f_tr, - const luci::CircleReshape *m_rs, - const luci::CircleTranspose *b_tr) +template bool TaggedShapeAnalyzer::can_remove_transposes() { - assert(loco::must_cast(f_tr->perm())->dtype() == DType); - assert(loco::must_cast(b_tr->perm())->dtype() == DType); + assert(_in != nullptr && _front_transpose != nullptr && _mid_reshape != nullptr && + _back_transpose != nullptr); - const luci::CircleNode *in_tensor = loco::must_cast(f_tr->a()); + // TODO: Update under methods to use std::vector intead of CircleNode* + init_shape_with_tag(_in); - init_shape_with_tag(in_tensor); + analyze_transpose(_front_transpose); - analyze_transpose(f_tr); - - if (not analyze_reshape(m_rs)) + if (not analyze_reshape(_mid_reshape)) return false; - analyze_transpose(b_tr); + analyze_transpose(_back_transpose); if (not verify_tag()) return false; @@ -339,34 +449,12 @@ bool remove_unnecessary_transpose(luci::CircleTranspose *node) return false; } - // check perm and shape are CircleConst node and its' datatype is S32 - const auto back_perm = dynamic_cast(back_transpose->perm()); - { - if (back_perm == nullptr) - return false; - - if (back_perm->dtype() != loco::DataType::S32) - return false; - } - const auto front_perm = dynamic_cast(front_transpose->perm()); - { - if (front_perm == nullptr) - return false; - - if (front_perm->dtype() != loco::DataType::S32) - return false; - } + TaggedShapeAnalyzer analyzer; - // for now, handle only rank reduction equal (not expansion) cases - const auto output_rank = back_transpose->rank(); - const auto input_rank = front_transpose->rank(); - if (input_rank < output_rank) + if (not analyzer.init(front_transpose, mid_reshape, back_transpose)) return false; - // analyze pattern to check this pass is applicable - TaggedShapeAnalyzer analyzer; - if (not analyzer.can_remove_transposes(front_transpose, mid_reshape, - back_transpose)) + if (not analyzer.can_remove_transposes()) return false; // repalce with new_node