From 07acbcb9165d03e18e0da73b56d907ee2509d9e6 Mon Sep 17 00:00:00 2001
From: Zheng-Bicheng <58363586+Zheng-Bicheng@users.noreply.github.com>
Date: Sat, 24 Aug 2024 12:36:11 +0800
Subject: [PATCH] Fixed matmul bug when export to rknn (#1355)

* remove matmul when export to rknn
* update
* update
* update
* Delete tests/test_auto_scan_hardswish.py
* Update test_hardswish.py
* add Slice
---
 .gitignore                                  |   3 +-
 VERSION_NUMBER                              |   2 +-
 paddle2onnx/mapper/activation/hard_swish.cc |  24 ++-
 paddle2onnx/mapper/activation/hard_swish.h  |   2 +
 paddle2onnx/mapper/quantize_helper.cc       | 178 +++++++++-----------
 paddle2onnx/mapper/quantize_helper.h        |   4 +-
 tests/test_auto_scan_hardswish.py           |  67 --------
 tests/test_hardswish.py                     |  17 +-
 8 files changed, 114 insertions(+), 183 deletions(-)
 delete mode 100644 tests/test_auto_scan_hardswish.py

diff --git a/.gitignore b/.gitignore
index 707a08b20..ed499725c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,4 +27,5 @@ protobuf-*
 *.temptxt
 tests/__pycache_*
 tests/*/__pycache_*
-tests/*/*.info
\ No newline at end of file
+tests/*/*.info
+tools/onnx/__pycache__*
diff --git a/VERSION_NUMBER b/VERSION_NUMBER
index c04c650a7..db6fb4a91 100644
--- a/VERSION_NUMBER
+++ b/VERSION_NUMBER
@@ -1 +1 @@
-1.2.7
+1.2.8
diff --git a/paddle2onnx/mapper/activation/hard_swish.cc b/paddle2onnx/mapper/activation/hard_swish.cc
index 84f9986ed..38579020d 100644
--- a/paddle2onnx/mapper/activation/hard_swish.cc
+++ b/paddle2onnx/mapper/activation/hard_swish.cc
@@ -17,6 +17,11 @@ namespace paddle2onnx {
 REGISTER_MAPPER(hard_swish, HardSwishMapper)
 
+int32_t HardSwishMapper::GetMinOpsetVersion(bool verbose) {
+  Logger(verbose, 14) << RequireOpset(14) << std::endl;
+  return 14;
+}
+
 void HardSwishMapper::Opset7() {
   auto input_info = GetInput("X");
   auto output_info = GetOutput("Out");
@@ -29,9 +34,24 @@ void HardSwishMapper::Opset7() {
   helper_->MakeNode("Div", {mul_node->output(0), scale_node},
                     {output_info[0].name});
 }
 
+inline bool IsAlmostEqual(float a, float b) {
+  constexpr float epsilon = 1e-5f;
+  return std::fabs(a - b) < epsilon;
+}
+
 void HardSwishMapper::Opset14() {
-  if (fabs(offset_ - 3.0) > 1e-05 || fabs(scale_ - 6.0) > 1e-05 ||
-      fabs(threshold_ - 6.0) > 1e-05) {
+  if (!IsAlmostEqual(offset_, 3.0)) {
+    P2OLogger() << "offset_ != 3.0, using Opset7()" << std::endl;
+    return Opset7();
+  }
+
+  if (!IsAlmostEqual(scale_, 6.0)) {
+    P2OLogger() << "scale_ != 6.0, using Opset7()" << std::endl;
+    return Opset7();
+  }
+
+  if (!IsAlmostEqual(threshold_, 6.0)) {
+    P2OLogger() << "threshold_ != 6.0, using Opset7()" << std::endl;
     return Opset7();
   }
   auto input_info = GetInput("X");
diff --git a/paddle2onnx/mapper/activation/hard_swish.h b/paddle2onnx/mapper/activation/hard_swish.h
index abaa427f1..07610c969 100644
--- a/paddle2onnx/mapper/activation/hard_swish.h
+++ b/paddle2onnx/mapper/activation/hard_swish.h
@@ -32,6 +32,8 @@ class HardSwishMapper : public Mapper {
     GetAttr("threshold", &threshold_);
   }
 
+  int32_t GetMinOpsetVersion(bool verbose) override;
+
   void Opset7() override;
   void Opset14() override;
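
For reference: Paddle's hard_swish computes x * min(max(x + offset, 0), threshold) / scale, while ONNX HardSwish (opset 14) is hard-wired to x * max(0, min(1, x/6 + 1/2)). The two agree exactly when (offset, scale, threshold) = (3, 6, 6), which is why Opset14() falls back to the Opset7() decomposition for any other attribute values. A minimal standalone sketch of the equivalence (illustrative only, not part of the patch):

    #include <algorithm>
    #include <cstdio>

    // Paddle hard_swish with its default attributes offset=3, threshold=6, scale=6.
    static float PaddleHardSwish(float x) {
      return x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;
    }

    // ONNX HardSwish (opset 14): x * max(0, min(1, x / 6 + 1 / 2)).
    static float OnnxHardSwish(float x) {
      return x * std::max(0.0f, std::min(1.0f, x / 6.0f + 0.5f));
    }

    int main() {
      // The two columns print identical values over the whole range.
      for (float x = -8.0f; x <= 8.0f; x += 0.5f) {
        std::printf("x=%+.1f  paddle=%+.4f  onnx=%+.4f\n", x,
                    PaddleHardSwish(x), OnnxHardSwish(x));
      }
      return 0;
    }
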
diff --git a/paddle2onnx/mapper/quantize_helper.cc b/paddle2onnx/mapper/quantize_helper.cc
index 3db7395bb..b47282364 100644
--- a/paddle2onnx/mapper/quantize_helper.cc
+++ b/paddle2onnx/mapper/quantize_helper.cc
@@ -14,8 +14,9 @@
 
 #include "paddle2onnx/mapper/quantize_helper.h"
 
-namespace paddle2onnx {
+#include <algorithm>
 
+namespace paddle2onnx {
 void QuantizeModelProcessor::RemoveNodeByName(const std::string& name,
                                               const bool& update_io) {
   if (name.empty()) {
@@ -168,15 +169,14 @@ void QuantizeModelProcessor::ProcessQuantizeModel(
-    // When deploy_backend is RKNN, use the follow four steps to process:
+    // When deploy_backend is RKNN, use the following six steps to process:
     // 1. broadcast quantize info
     // 2. remove all quantize ops
-    // 3. merge conv and add
-    // 4. merge conv and bn
-    // 5. add Q and DQ
-    // 6. use topo sort in nodes
+    // 3. remove identity op
+    // 4. merge conv and add
+    // 5. add Q and DQ
+    // 6. use topo sort in nodes
     QuantizeInfoBroadcast();
     RemoveAllQuantizeOps();
     RemoveIdentityOp();
     MergeConvAdd();
-    MergeConvBN();
     AddQDQForRKNN();
     SortNodes();
   } else {
@@ -217,6 +215,7 @@ void QuantizeModelProcessor::AddQDQForRKNN() {
       "Cos",
       "Cosh",
       "Concat",
+      "Div",
       "Elu",
       "Erf",
       "Exp",
       "Gemm",
       "GlobalAveragePool",
       "HardSigmoid",
+      "HardSwish",
       "InstanceNormalization",
       "IsInf",
       "IsNaN",
       "Log",
+      "MatMul",
       "MaxPool",
       "Mul",
       "Neg",
       "ReduceMean",
       "Relu",
+      "Reshape",
       "Resize",
       "Round",
       "Sigmoid",
       "Sin",
       "Sinh",
+      "Slice",
+      "Softmax",
       "Split",
       "Sqrt",
       "Tan",
-      "MatMul",
-      "Tanh"};
+      "Tanh",
+      "Transpose"};
   for (auto iter = nodes_->begin(); iter < nodes_->end(); iter++) {
     auto node = *iter;
-    auto type_iter = std::find(supported_quantize_type_.begin(),
-                               supported_quantize_type_.end(), node->op_type());
-    if (!supported_quantize_type_.empty() &&
-        type_iter == supported_quantize_type_.end()) {
+    auto type_iter = std::find(supported_quantize_type_.begin(), supported_quantize_type_.end(), node->op_type());
+    if (!supported_quantize_type_.empty() && type_iter == supported_quantize_type_.end()) {
       continue;
     }
-    if (node->op_type() == "MatMul") {
-      std::vector<std::string> tensor_names = {node->input(0), node->input(1),
-                                               node->output(0)};
+    std::vector<std::string> tensor_names = {};
+    for (size_t i = 0; i < node->input_size(); ++i) {
+      std::string node_input = node->input(i);
+      tensor_names.push_back(node_input);
+    }
+    for (size_t i = 0; i < node->output_size(); ++i) {
+      std::string node_output = node->output(i);
+      tensor_names.push_back(node_output);
+    }
+
+    if (node->op_type() == "MatMul" || node->op_type() == "Add" || node->op_type() == "Mul") {
       for (auto& name : tensor_names) {
         if (helper_->quantize_info.find(name) != helper_->quantize_info.end()) {
           continue;
         }
+
         std::vector<float> matmul_weight;
         if (!GetTensorByName(name, &matmul_weight)) {
+          P2OLogger() << "Failed to GetTensorByName: " << node->op_type() << ";" << name << std::endl;
           continue;
         }
+
         std::vector<int64_t> matmul_weight_shape;
         if (!GetTensorShape(name, &matmul_weight_shape)) {
+          P2OLogger() << "Failed to GetTensorShape: " << node->op_type() << ";" << name << std::endl;
           continue;
         }
+
         int64_t quantize_axis = 1;
         std::vector<float> scale;
         std::vector<int64_t> zeros;
-        GetChannelWiseQuantizeInfo(matmul_weight, matmul_weight_shape,
-                                   quantize_axis, &scale, &zeros);
-        auto scale_node =
-            helper_->Constant(ONNX_NAMESPACE::TensorProto::FLOAT, scale);
-        auto zero_node =
-            helper_->Constant(ONNX_NAMESPACE::TensorProto::INT8, zeros);
-        QuantizeInfo matmul_weight_quantize_info(scale, zeros, scale_node,
-                                                 zero_node, quantize_axis);
-        helper_->quantize_info[name] = matmul_weight_quantize_info;
-      }
-      if (!CanBeQuantize(tensor_names)) {
-        tensor_names.pop_back();
-        if (!CanBeQuantize(tensor_names)) {
-          continue;
+        if (matmul_weight_shape.size() == 1) {
+          quantize_axis = 0;
         }
+        GetChannelWiseQuantizeInfo(matmul_weight, matmul_weight_shape, quantize_axis, &scale, &zeros);
+        std::string scale_node, zero_node;
+
+        if (scale.size() == 1) {
+          scale_node = helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, scale[0]);
+          zero_node = helper_->Constant({}, ONNX_NAMESPACE::TensorProto::INT8, zeros[0]);
+        } else {
+          scale_node = helper_->Constant(ONNX_NAMESPACE::TensorProto::FLOAT, scale);
+          zero_node = helper_->Constant(ONNX_NAMESPACE::TensorProto::INT8, zeros);
+        }
+        QuantizeInfo matmul_weight_quantize_info(scale, zeros, scale_node, zero_node, quantize_axis);
+        helper_->quantize_info[name] = matmul_weight_quantize_info;
       }
-      for (auto& name : tensor_names) {
-        AppendQuantizeTensor(name);
-      }
-    }
-
-    std::vector<std::string> tensor_names;
-    for (size_t i = 0; i < node->input_size(); ++i) {
-      std::string node_input = node->input(i);
-      tensor_names.push_back(node_input);
-    }
-    for (size_t i = 0; i < node->output_size(); ++i) {
-      std::string node_output = node->output(i);
-      tensor_names.push_back(node_output);
+    } else if (node->op_type() == "BatchNormalization") {
+      // BatchNormalization only needs to quantize X and Y.
+      // When opset > 9, tensor_names is {X, scale, B, input_mean, input_var, Y, running_mean, running_var}.
+      // When opset <= 9, tensor_names is {X, scale, B, mean, var, Y, mean, var, saved_mean, saved_var}.
+      tensor_names.erase(tensor_names.begin() + 1, tensor_names.begin() + 5);
+      tensor_names.erase(tensor_names.begin() + 2, tensor_names.end());
     }
+
     if (!CanBeQuantize(tensor_names)) {
       continue;
     }
+
     for (auto& name : tensor_names) {
       AppendQuantizeTensor(name);
     }
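
The two erase calls in the BatchNormalization branch above are pure index arithmetic over the op's I/O order, keeping only X (input 0) and Y (output 0). A self-contained check of that slicing (illustrative only, assuming the opset > 9 name order given in the comment):

    #include <cassert>
    #include <string>
    #include <vector>

    int main() {
      // BatchNormalization tensors at opset > 9, inputs followed by outputs.
      std::vector<std::string> names = {"X", "scale", "B", "input_mean",
                                        "input_var", "Y", "running_mean",
                                        "running_var"};
      // Drop scale/B/input_mean/input_var, keeping X ...
      names.erase(names.begin() + 1, names.begin() + 5);
      // ... then drop everything after Y.
      names.erase(names.begin() + 2, names.end());
      assert((names == std::vector<std::string>{"X", "Y"}));
      return 0;
    }
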
@@ -754,16 +764,13 @@ void QuantizeModelProcessor::MergeConvAdd() {
       continue;
     }
     // if act input of conv does not have quantize info, continue
-    bool act_has_quantize_info = helper_->quantize_info.find(node->input(0)) !=
-                                 helper_->quantize_info.end();
+    bool act_has_quantize_info = helper_->quantize_info.find(node->input(0)) != helper_->quantize_info.end();
     if (!act_has_quantize_info) {
       continue;
     }
     // if weight of conv does not have quantize info, continue
-    bool weight_has_quantize_info =
-        helper_->quantize_info.find(node->input(1)) !=
-        helper_->quantize_info.end();
+    bool weight_has_quantize_info = helper_->quantize_info.find(node->input(1)) != helper_->quantize_info.end();
     if (!weight_has_quantize_info) {
       continue;
     }
@@ -808,18 +815,15 @@ void QuantizeModelProcessor::MergeConvAdd() {
       continue;
     }
     // continue if shape_val != [1, bias_val.size(), 1, 1]
-    std::vector<int64_t> target = {1, static_cast<int64_t>(bias_val.size()), 1,
-                                   1};
+    std::vector<int64_t> target = {1, static_cast<int64_t>(bias_val.size()), 1, 1};
     if (target != shape_val) {
       continue;
     }
     // remove Reshape op
     RemoveNodeByName(before_nodes[0]->name());
     // add scale for bias
-    std::vector<float> weight_scale =
-        helper_->quantize_info[node->input(1)].scale_;
-    std::vector<float> act_scale =
-        helper_->quantize_info[node->input(0)].scale_;
+    std::vector<float> weight_scale = helper_->quantize_info[node->input(1)].scale_;
+    std::vector<float> act_scale = helper_->quantize_info[node->input(0)].scale_;
     std::vector<float> bias_scale;
     for (int64_t i = 0; i < weight_scale.size(); i++) {
       bias_scale.push_back(weight_scale[i] * act_scale[0]);
     }
@@ -830,8 +834,7 @@ void QuantizeModelProcessor::MergeConvAdd() {
     auto zero_node =
         helper_->Constant(ONNX_NAMESPACE::TensorProto::INT32, onnx_zeros);
 
-    QuantizeInfo quantize_info(bias_scale, onnx_zeros, scale_node, zero_node,
-                               0);
+    QuantizeInfo quantize_info(bias_scale, onnx_zeros, scale_node, zero_node, 0);
     helper_->quantize_info[bias_node] = quantize_info;
     AppendQuantizeTensor(bias_node, true);
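
The bias handling above follows the usual INT32-bias convention: the convolution accumulates in units of act_scale * weight_scale[i], so the merged bias must be quantized with exactly that per-channel scale and a zero point of 0. A minimal sketch of the folding (FoldBiasScales is an illustrative name, not the helper's API):

    #include <vector>

    // bias_scale[i] = act_scale * weight_scale[i]; the bias would then be
    // stored as INT32 values round(bias[i] / bias_scale[i]) with zero point 0,
    // so dequantized bias and dequantized conv output share the same scale.
    static std::vector<float> FoldBiasScales(
        const std::vector<float>& weight_scale, float act_scale) {
      std::vector<float> bias_scale;
      bias_scale.reserve(weight_scale.size());
      for (float ws : weight_scale) {
        bias_scale.push_back(ws * act_scale);
      }
      return bias_scale;
    }
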
@@ -1048,18 +1051,21 @@ void QuantizeModelProcessor::GetTensorWiseQuantizeInfo(
   zero->push_back(0);
 }
 
-void QuantizeModelProcessor::GetChannelWiseQuantizeInfo(
-    const std::vector<float>& tensor, const std::vector<int64_t>& shape,
-    const int64_t& quant_axis, std::vector<float>* scale,
-    std::vector<int64_t>* zero) {
-  int64_t channel_count = shape[quant_axis];
-
-  for (int64_t i = 0; i < channel_count; i++) {
-    if (quant_axis == 0) {
+void QuantizeModelProcessor::GetChannelWiseQuantizeInfo(const std::vector<float>& tensor,
+                                                        const std::vector<int64_t>& shapes,
+                                                        int64_t quant_axis,
+                                                        std::vector<float>* scale,
+                                                        std::vector<int64_t>* zero) {
+  int64_t channel_count = 1;
+  if (shapes.size() != 1) {
+    quant_axis = 1;
+  }
+  if (quant_axis == 0) {
+    for (int64_t i = 0; i < channel_count; i++) {
       float max_val = -1;
       int64_t inner_offset = 1;
-      for (auto& j : shape) {
-        inner_offset *= j;
+      for (auto& shape : shapes) {
+        inner_offset *= shape;
       }
       inner_offset /= channel_count;
       int64_t index = i * inner_offset;
@@ -1068,36 +1074,19 @@ void QuantizeModelProcessor::GetChannelWiseQuantizeInfo(
           max_val = fabs(tensor[index + j]);
         }
       }
-      Assert(
-          max_val >= 0,
-          "[GetChannelWiseQuantizeInfo] Require the scale >= 0, but now it's " +
-              std::to_string(max_val) + ".");
-      scale->push_back(max_val / 127);
-      zero->push_back(0);
-    } else if (quant_axis == 1) {
-      float max_val = -1;
-      int64_t inner_offset = shape.size() == 4 ? shape[2] * shape[3] : 1;
-      for (int64_t outter = 0; outter < shape[0]; outter++) {
-        int64_t index = outter * channel_count * inner_offset;
-        for (int64_t inner = 0; inner < inner_offset; inner++) {
-          int64_t final_index = index + i * inner_offset + inner;
-          if (fabs(tensor[final_index]) > max_val) {
-            max_val = fabs(tensor[final_index]);
-          }
-        }
-      }
-      Assert(
-          max_val >= 0,
-          "[GetChannelWiseQuantizeInfo] Require the scale >= 0, but now it's " +
-              std::to_string(max_val) + ".");
+      Assert(max_val >= 0, "[GetChannelWiseQuantizeInfo] Require the scale >= 0, but now it's " + std::to_string(max_val) + ".");
       scale->push_back(max_val / 127);
       zero->push_back(0);
-    } else {
-      Assert(false,
-             "QuantizeModelProcessor::GetChannelWiseQuantizeInfo only supports "
-             "quant_axis equals to 0 or 1, but now it's " +
-                 std::to_string(quant_axis) + ".");
     }
+  } else if (quant_axis == 1) {
+    // Take the largest absolute value so negative weights get a correct scale.
+    auto max_val = std::fabs(*std::max_element(
+        tensor.begin(), tensor.end(),
+        [](float a, float b) { return std::fabs(a) < std::fabs(b); }));
+    Assert(max_val >= 0, "[GetChannelWiseQuantizeInfo] Require the scale >= 0, but now it's " + std::to_string(max_val) + ".");
+    scale->push_back(max_val / 127);
+    zero->push_back(0);
+  } else {
+    Assert(false,
+           "QuantizeModelProcessor::GetChannelWiseQuantizeInfo only supports quant_axis equals to 0 or 1, "
+           "but now it's " + std::to_string(quant_axis) + ".");
   }
 }
@@ -1149,8 +1138,8 @@ bool QuantizeModelProcessor::CanBeQuantize(
       return false;
     }
   }
-  // If there is an OP linked to the output by identity, it needs to be skipped,
-  // do not quantize the OP
+
+  // If there is an OP linked to the output by identity, it needs to be skipped, do not quantize the OP
   for (auto i = 0; i < output_index.size(); i++) {
     int64_t index = output_index[i];
     if (index == -1) {
@@ -1159,6 +1148,7 @@ bool QuantizeModelProcessor::CanBeQuantize(
 
     std::string output_name = tensor_names[index];
     if (ConnectToOutput(output_name)) {
+      P2OLogger() << "ConnectToOutput: " << output_name << std::endl;
       return false;
     }
   }
diff --git a/paddle2onnx/mapper/quantize_helper.h b/paddle2onnx/mapper/quantize_helper.h
index 59d76d150..b4cdde039 100755
--- a/paddle2onnx/mapper/quantize_helper.h
+++ b/paddle2onnx/mapper/quantize_helper.h
@@ -117,8 +117,8 @@ struct QuantizeModelProcessor {
   // Perform channel wise quantization, returning scale and zero
   void GetChannelWiseQuantizeInfo(const std::vector<float>& tensor,
-                                  const std::vector<int64_t>& shape,
-                                  const int64_t& quant_axis,
+                                  const std::vector<int64_t>& shapes,
+                                  int64_t quant_axis,
                                   std::vector<float>* scale,
                                   std::vector<int64_t>* zero);
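
Note that with channel_count pinned to 1 and quant_axis forced to 1 for any tensor that is not 1-D, both branches of the revised GetChannelWiseQuantizeInfo collapse to a single symmetric max-abs scale over the whole tensor. A condensed standalone sketch of that net behaviour (assuming the max-abs comparator above; the function name is illustrative):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // One symmetric INT8 scale for the whole tensor: max(|t|) / 127, zero point 0.
    static void MaxAbsQuantizeInfo(const std::vector<float>& tensor,
                                   std::vector<float>* scale,
                                   std::vector<int64_t>* zero) {
      float max_abs = 0.0f;
      for (float v : tensor) {
        max_abs = std::max(max_abs, std::fabs(v));
      }
      scale->push_back(max_abs / 127.0f);
      zero->push_back(0);
    }
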
diff --git a/tests/test_auto_scan_hardswish.py b/tests/test_auto_scan_hardswish.py
deleted file mode 100644
index 8d1656eaf..000000000
--- a/tests/test_auto_scan_hardswish.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from auto_scan_test import OPConvertAutoScanTest, BaseNet
-from hypothesis import reproduce_failure
-import hypothesis.strategies as st
-import numpy as np
-import unittest
-import paddle
-
-
-class Net(BaseNet):
-    """
-    simple Net
-    """
-
-    def forward(self, inputs):
-        """
-        forward
-        """
-        x = paddle.nn.functional.hardswish(inputs)
-        return x
-
-
-class TestHardswishConvert(OPConvertAutoScanTest):
-    """
-    api: paddle.nn.functional.hardswish
-    OPset version: 7, 14
-    """
-
-    def sample_convert_config(self, draw):
-        input_shape = draw(
-            st.lists(
-                st.integers(
-                    min_value=10, max_value=20), min_size=0, max_size=4))
-
-        dtype = draw(st.sampled_from(["float32"]))
-
-        config = {
-            "op_names": ["hard_swish"],
-            "test_data_shapes": [input_shape],
-            "test_data_types": [[dtype]],
-            "opset_version": [7, 14],
-            "input_spec_shape": [],
-        }
-
-        models = Net(config)
-
-        return (config, models)
-
-    def test(self):
-        self.run_and_statis(max_examples=30)
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/tests/test_hardswish.py b/tests/test_hardswish.py
index 5e7413436..484cfee25 100644
--- a/tests/test_hardswish.py
+++ b/tests/test_hardswish.py
@@ -33,21 +33,6 @@ def forward(self, inputs):
         return x
 
 
-def test_hardswish_7():
-    """
-    api: paddle.hardswish
-    op version: 7
-    """
-    op = Net()
-    op.eval()
-    # net, name, ver_list, delta=1e-6, rtol=1e-5
-    obj = APIOnnx(op, 'hardswish', [7])
-    obj.set_input_data(
-        "input_data",
-        paddle.to_tensor(randtool("float", -1, 1, [3, 10]).astype('float32')))
-    obj.run()
-
-
 def test_hardswish_14():
     """
     api: paddle.hardswish
@@ -60,4 +45,4 @@ def test_hardswish_14():
     obj.set_input_data(
         "input_data",
         paddle.to_tensor(randtool("float", -1, 1, [3, 10]).astype('float32')))
-    obj.run()
\ No newline at end of file
+    obj.run()