Skip to content

Commit

Permalink
Fixed MobileNetV3 Quantized Model for RKNN (#1390)
Browse files Browse the repository at this point in the history
* update

* update
  • Loading branch information
Zheng-Bicheng authored Sep 20, 2024
1 parent 2f9ad56 commit 48fa34e
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions paddle2onnx/mapper/quantize/rknn_quantize_processor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,23 +90,33 @@ void RKNNQuantizeProcessor::AddQDQ() {
if (helper_->quantize_info.find(name) != helper_->quantize_info.end()) {
continue;
}
std::vector<float> matmul_weight;
if (!GetTensorByName(name, &matmul_weight)) {
std::vector<float> weight_data;
if (!GetTensorByName(name, &weight_data)) {
continue;
}
std::vector<int64_t> matmul_weight_shape;
if (!GetTensorShape(name, &matmul_weight_shape)) {
std::vector<int64_t> weight_shape;
if (!GetTensorShape(name, &weight_shape)) {
continue;
}

int64_t quantize_axis = 1;
std::vector<float> scale;
std::vector<int64_t> zeros;
GetChannelWiseQuantizeInfo(matmul_weight, matmul_weight_shape,
quantize_axis, &scale, &zeros);
auto scale_node =
helper_->Constant(ONNX_NAMESPACE::TensorProto::FLOAT, scale);
auto zero_node =
helper_->Constant(ONNX_NAMESPACE::TensorProto::INT8, zeros);
std::string scale_node, zero_node;
if (weight_shape.size() <= 1) {
GetTensorWiseQuantizeInfo(weight_data, &scale, &zeros);
scale_node = helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT,
scale[0]);
zero_node = helper_->Constant({}, ONNX_NAMESPACE::TensorProto::INT8,
zeros[0]);
} else {
GetChannelWiseQuantizeInfo(weight_data, weight_shape, quantize_axis,
&scale, &zeros);
scale_node =
helper_->Constant(ONNX_NAMESPACE::TensorProto::FLOAT, scale);
zero_node =
helper_->Constant(ONNX_NAMESPACE::TensorProto::INT8, zeros);
}
QuantizeInfo matmul_weight_quantize_info(scale, zeros, scale_node,
zero_node, quantize_axis);
helper_->quantize_info[name] = matmul_weight_quantize_info;
Expand Down Expand Up @@ -169,13 +179,11 @@ void RKNNQuantizeProcessor::PerchannelToPerlayer() {

auto next_nodes = name2node_dict_[node->output(0)];
if (next_nodes.size() > 1 || IsGraphOutput(node->output(0))) {
P2OLogger() << "Type1" << std::endl;
continue;
}

auto add_node = next_nodes[0];
if (add_node->op_type() != "Add" || IsGraphOutput(add_node->output(0))) {
P2OLogger() << "Type2" << std::endl;
continue;
}

Expand Down

0 comments on commit 48fa34e

Please sign in to comment.