Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[luci/pass] Add TaggedShapeAnalyzer::init() in RmUnnTransNetPass #14152

Merged
merged 5 commits into from
Oct 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 128 additions & 40 deletions compiler/luci/pass/src/RemoveUnnecessaryTransposeNetPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,62 @@
#include <luci/IR/CircleNodes.h>
#include <luci/Profile/CircleNodeOrigin.h>

#include <limits>
#include <vector>

namespace
{

// Early-exit helper: makes the enclosing bool-returning function return false
// as soon as `condition` evaluates to false. The expansion already ends with
// ';', so the macro can be written with or without a trailing semicolon.
#define CHECK_OR_FALSE(condition) \
  if (!(condition))               \
    return false;

bool extract_shape(const luci::CircleNode *node, std::vector<int32_t> &shape)
{
uint32_t max_i32 = static_cast<uint32_t>(std::numeric_limits<int32_t>::max());

auto rank = node->rank();
for (auto i = 0u; i < rank; ++i)
{
uint32_t v = node->dim(i).value();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it okay to not consider the 64-bit range?

Suggested change
uint32_t v = node->dim(i).value();
uint64_t v = node->dim(i).value();

Copy link
Contributor Author

@zetwhite zetwhite Oct 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shs-park Thank you for the review!

uint32_t value(void) const { return _value; }

loco::Dimension only holds uint32_t values. So we don't need to handle int64_t for this function.

CHECK_OR_FALSE(v <= max_i32)
shape.push_back(static_cast<int32_t>(v));
}
return true;
};

bool extract_const(const luci::CircleConst *const_node, std::vector<int32_t> &values)
{
auto dtype = const_node->dtype();

if (dtype == loco::DataType::S32)
{
auto size = const_node->size<loco::DataType::S32>();
for (auto i = 0u; i < size; ++i)
{
int32_t v = const_node->at<loco::DataType::S32>(i);
values.push_back(v);
}
}
else if (dtype == loco::DataType::S64)
{
int64_t max_i32 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add #include <limits> to fix build error.

int64_t min_i32 = static_cast<int64_t>(std::numeric_limits<int32_t>::lowest());

auto size = const_node->size<loco::DataType::S64>();
for (auto i = 0u; i < size; ++i)
{
int64_t v = const_node->at<loco::DataType::S64>(i);
CHECK_OR_FALSE(min_i32 <= v && v <= max_i32);
values.push_back(static_cast<int32_t>(v));
}
}
else
return false;

return true;
};

struct TagDim final
{
int32_t value;
Expand All @@ -35,9 +86,10 @@ using TagShape = std::vector<TagDim>;
class TaggedShapeAnalyzer final
{
public:
template <loco::DataType DType>
bool can_remove_transposes(const luci::CircleTranspose *f_tr, const luci::CircleReshape *m_rs,
const luci::CircleTranspose *b_tr);
bool init(const luci::CircleTranspose *, const luci::CircleReshape *,
const luci::CircleTranspose *);

template <loco::DataType DType> bool can_remove_transposes();

private:
void init_shape_with_tag(const luci::CircleNode *);
Expand All @@ -48,6 +100,17 @@ class TaggedShapeAnalyzer final

bool verify_tag() const;

private:
const luci::CircleNode *_in = nullptr;
const luci::CircleTranspose *_front_transpose = nullptr;
const luci::CircleReshape *_mid_reshape = nullptr;
const luci::CircleTranspose *_back_transpose = nullptr;

std::vector<int32_t> _in_shape_v;
std::vector<int32_t> _front_perm_v;
std::vector<int32_t> _mid_shape_v;
std::vector<int32_t> _back_perm_v;

const uint8_t START_TAG = 0;
TagShape _shape;
};
Expand Down Expand Up @@ -215,9 +278,60 @@ bool TaggedShapeAnalyzer::verify_tag() const
return true;
}

/**
 * @brief Initialize the class members and check the preconditions of the analyzer
 *
 * Conditions that have to be met for the analyzer
 *   c1: input rank >= output rank
 *   c2: The 'perm' of each transpose should be a CircleConst* type
 *   c3: The input shape and the reshape node's shape should be known
 *
 * Calling init() again on the same analyzer is safe: shape/perm values collected
 * by a previous call are discarded before new ones are extracted.
 *
 * @param front_transpose  Transpose in front of the reshape; its input becomes the analyzed input
 * @param mid_reshape      Reshape between the two transposes
 * @param back_transpose   Transpose behind the reshape
 *
 * @return True, if all conditions are satisfied and class members are initialized successfully
 *         False, otherwise
 */
bool TaggedShapeAnalyzer::init(const luci::CircleTranspose *front_transpose,
                               const luci::CircleReshape *mid_reshape,
                               const luci::CircleTranspose *back_transpose)
{
  _in = loco::must_cast<luci::CircleNode *>(front_transpose->a());
  _front_transpose = front_transpose;
  _mid_reshape = mid_reshape;
  _back_transpose = back_transpose;

  // extract_shape()/extract_const() only append to their output vector, so drop
  // anything left over from an earlier init() call before extracting again.
  _in_shape_v.clear();
  _front_perm_v.clear();
  _mid_shape_v.clear();
  _back_perm_v.clear();

  // check c1
  CHECK_OR_FALSE(_in->rank() >= _back_transpose->rank());

  const auto front_perm = dynamic_cast<luci::CircleConst *>(_front_transpose->perm());
  const auto back_perm = dynamic_cast<luci::CircleConst *>(_back_transpose->perm());

  // check c2
  CHECK_OR_FALSE(front_perm != nullptr);
  CHECK_OR_FALSE(back_perm != nullptr);

  // Each extraction fails when a value cannot be represented as int32_t
  // (or, for constants, when the dtype is neither S32 nor S64).
  CHECK_OR_FALSE(extract_shape(_in, _in_shape_v));
  CHECK_OR_FALSE(extract_const(front_perm, _front_perm_v));
  CHECK_OR_FALSE(extract_shape(_mid_reshape, _mid_shape_v));
  CHECK_OR_FALSE(extract_const(back_perm, _back_perm_v));

  // A dimension whose value is not positive is treated as "not known" here.
  auto all_known = [](const std::vector<int32_t> &v) -> bool {
    for (auto i : v)
      if (i <= 0)
        return false;
    return true;
  };

  // check c3
  CHECK_OR_FALSE(all_known(_in_shape_v));
  CHECK_OR_FALSE(all_known(_mid_shape_v));

  return true;
}

/**
* @brief check 'Transpose-Reshape-Transpose' can be replaced by one 'Reshape'.
*
* @warning '@init' has to be called first
zetwhite marked this conversation as resolved.
Show resolved Hide resolved
*
* @example
* Let's explain how analyzer check Transpose-Reshape-Transpose pattern with an exact example.
*
Expand Down Expand Up @@ -255,24 +369,20 @@ bool TaggedShapeAnalyzer::verify_tag() const
* Transpose has no effect in final shape, which they can be removed as
* unnecessary Ops.
*/
template <loco::DataType DType>
bool TaggedShapeAnalyzer::can_remove_transposes(const luci::CircleTranspose *f_tr,
const luci::CircleReshape *m_rs,
const luci::CircleTranspose *b_tr)
template <loco::DataType DType> bool TaggedShapeAnalyzer::can_remove_transposes()
{
assert(loco::must_cast<luci::CircleConst *>(f_tr->perm())->dtype() == DType);
assert(loco::must_cast<luci::CircleConst *>(b_tr->perm())->dtype() == DType);
assert(_in != nullptr && _front_transpose != nullptr && _mid_reshape != nullptr &&
_back_transpose != nullptr);

const luci::CircleNode *in_tensor = loco::must_cast<luci::CircleNode *>(f_tr->a());
// TODO: Update the methods below to use std::vector<int32_t> & instead of CircleNode *
init_shape_with_tag(_in);

init_shape_with_tag(in_tensor);
analyze_transpose<DType>(_front_transpose);

analyze_transpose<DType>(f_tr);

if (not analyze_reshape(m_rs))
if (not analyze_reshape(_mid_reshape))
return false;

analyze_transpose<DType>(b_tr);
analyze_transpose<DType>(_back_transpose);

if (not verify_tag())
return false;
Expand Down Expand Up @@ -339,34 +449,12 @@ bool remove_unnecessary_transpose(luci::CircleTranspose *node)
return false;
}

// check perm and shape are CircleConst node and its' datatype is S32
const auto back_perm = dynamic_cast<luci::CircleConst *>(back_transpose->perm());
{
if (back_perm == nullptr)
return false;

if (back_perm->dtype() != loco::DataType::S32)
return false;
}
const auto front_perm = dynamic_cast<luci::CircleConst *>(front_transpose->perm());
{
if (front_perm == nullptr)
return false;

if (front_perm->dtype() != loco::DataType::S32)
return false;
}
TaggedShapeAnalyzer analyzer;

// for now, handle only rank reduction equal (not expansion) cases
const auto output_rank = back_transpose->rank();
const auto input_rank = front_transpose->rank();
if (input_rank < output_rank)
if (not analyzer.init(front_transpose, mid_reshape, back_transpose))
return false;

// analyze pattern to check this pass is applicable
TaggedShapeAnalyzer analyzer;
if (not analyzer.can_remove_transposes<loco::DataType::S32>(front_transpose, mid_reshape,
back_transpose))
if (not analyzer.can_remove_transposes<loco::DataType::S32>())
return false;

// replace with new_node
Expand Down