Skip to content

Commit

Permalink
Merge pull request #1350 from risemeup1/fix_assign_value_op_bug
Browse files Browse the repository at this point in the history
Fix assign_value op attribute-parsing bug (fp64 values attribute name and missing scalars support)
  • Loading branch information
Jiang-Jia-Jun authored Aug 15, 2024
2 parents 4ca0809 + eb89650 commit c6dcef7
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 32 deletions.
30 changes: 7 additions & 23 deletions paddle2onnx/mapper/mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,33 +183,17 @@ class Mapper {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
return parser_->OpHasAttr(op, name);
}
// Typed attribute getters: each overload fetches attribute `name` of the
// op currently being mapped (located by block_idx_/op_idx_) into *val,
// delegating to the matching PaddleParser::GetOpAttr overload.
void GetAttr(const std::string &name, int64_t *val) {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
}
void GetAttr(const std::string &name, float *val) {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
}
void GetAttr(const std::string &name, bool *val) {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
}
void GetAttr(const std::string &name, std::string *val) {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
}
void GetAttr(const std::string &name, std::vector<int64_t> *val) {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
}
void GetAttr(const std::string &name, std::vector<float> *val) {

template<typename T>
void GetAttr(const std::string &name, T* val) {
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
}
void GetAttr(const std::string &name, std::vector<double> *val) {

template<typename T>
void GetScalars(const std::string &name, std::vector<T>* val){
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpAttr(op, name, val);
parser_->GetOpScalarsAttr(op, name, val);
}

bool IsConstantInput(const std::string &input_key) const {
Expand Down
63 changes: 55 additions & 8 deletions paddle2onnx/mapper/tensor/assign_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
#pragma once
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"
#include <unordered_map>
#include <functional>

namespace paddle2onnx {

Expand All @@ -27,21 +28,67 @@ class AssignValueMapper : public Mapper {
: Mapper(p, helper, block_id, op_id) {
GetAttr("dtype", &dtype_);
GetAttr("shape", &shape_);
GetAttrValues();
}
int32_t GetMinOpsetVersion(bool verbose) override;
void Opset7() override;

private:
// Loads the assign_value op's constant payload into the member vector that
// matches dtype_. Newer models store the payload in a generic "values"
// scalars attribute; older models use a per-dtype attribute (e.g.
// "int32_values"), so both spellings are handled per dtype.
void GetAttrValues() {
  int32_t dtype = static_cast<int32_t>(dtype_);
  // Prefer the unified "values" attribute when present; otherwise fall
  // back to the legacy per-dtype attribute name.
  const std::string attr_name =
      HasAttr("values") ? "values" : GetAttrNameByDtype(dtype);
  std::unordered_map<int32_t, std::function<void()>> type_handlers = {
      {P2ODataType::INT32, [&]() {
         // int32 payloads are widened into int64_values_.
         if (attr_name == "values") GetScalars(attr_name, &int64_values_);
         else if (attr_name == "int32_values") GetAttr(attr_name, &int64_values_);
       }},
      {P2ODataType::INT64, [&]() {
         if (attr_name == "values") GetScalars(attr_name, &int64_values_);
         else if (attr_name == "int64_values") GetAttr(attr_name, &int64_values_);
       }},
      {P2ODataType::FP32, [&]() {
         if (attr_name == "values") GetScalars(attr_name, &fp32_values_);
         else if (attr_name == "fp32_values") GetAttr(attr_name, &fp32_values_);
       }},
      {P2ODataType::FP64, [&]() {
         if (attr_name == "values") GetScalars(attr_name, &double_values_);
         // BUG FIX: the legacy attribute for FP64 is "double_values" (see
         // GetAttrNameByDtype), not "fp32_values" — the old comparison could
         // never match, silently leaving double_values_ empty for legacy
         // fp64 models.
         else if (attr_name == "double_values") GetAttr(attr_name, &double_values_);
       }},
      {P2ODataType::BOOL, [&]() {
         if (attr_name == "values") GetScalars(attr_name, &bool_values_);
         else if (attr_name == "bool_values") GetAttr(attr_name, &bool_values_);
       }},
  };

  auto handler = type_handlers.find(dtype);
  if (handler != type_handlers.end()) {
    handler->second();
  } else {
    Error() << "Unsupported dtype value" << std::endl;
  }
}

// Maps a Paddle dtype to the legacy per-dtype attribute name that stores
// the assign_value constant payload (used when the generic "values"
// attribute is absent).
// NOTE: the span replaced here interleaved deleted pre-commit lines
// (old GetAttr calls) with the new return statements; this is the clean
// post-commit form.
std::string GetAttrNameByDtype(int32_t dtype) {
  if (dtype == P2ODataType::INT32) {
    return "int32_values";
  } else if (dtype == P2ODataType::INT64) {
    return "int64_values";
  } else if (dtype == P2ODataType::FP32) {
    return "fp32_values";
  } else if (dtype == P2ODataType::FP64) {
    return "double_values";
  } else if (dtype == P2ODataType::BOOL) {
    return "bool_values";
  }
  Error() << "Unsupported dtype value" << std::endl;
  // BUG FIX: flowing off the end of a non-void function is undefined
  // behavior; return a sentinel that matches no attribute so callers fail
  // deterministically.
  return "";
}
int32_t GetMinOpsetVersion(bool verbose) override;
void Opset7() override;

private:
std::vector<float> fp32_values_;
std::vector<int64_t> int64_values_;
std::vector<bool> bool_values_;
std::vector<double> double_values_;
std::vector<int32_t> int32_values_;
std::vector<int64_t> shape_;
int64_t dtype_;
};
Expand Down
53 changes: 52 additions & 1 deletion paddle2onnx/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,7 @@ void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
Assert(found, "Cannot found attribute " + name + " in op: " + op.type());
}


void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
const std::string& name,
std::vector<double>* res) const {
Expand All @@ -759,7 +760,26 @@ void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
}
Assert(found, "Cannot found attribute " + name + " in op: " + op.type());
}

// Reads the boolean-list attribute `name` from `op` into *res.
// Mirrors the other GetOpAttr overloads: breaks early (leaving *res empty)
// when the attribute is held by a Var rather than inline data, and asserts
// that the attribute exists on the op at all.
void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
                             const std::string& name,
                             std::vector<bool>* res) const {
  bool found = false;
  res->clear();
  for (auto i = 0; i < op.attrs_size(); ++i) {
    if (op.attrs(i).name() == name) {
      found = true;
      if (IsAttrVar(op, i)) break;
      // FIX: message said "list of double data" — copy-paste from the
      // vector<double> overload; this overload reads bools.
      Assert(op.attrs(i).bools_size() >= 0,
             "Cannot find list of bool data from attr: " + name + " in op: " +
                 op.type());
      for (auto j = 0; j < op.attrs(i).bools_size(); ++j) {
        // FIX: cast to bool (was static_cast<double>, another copy-paste
        // slip; the value is stored into a std::vector<bool>).
        res->push_back(static_cast<bool>(op.attrs(i).bools(j)));
      }
      break;
    }
  }
  Assert(found, "Cannot found attribute " + name + " in op: " + op.type());
}
void PaddleParser::GetGlobalBlockInputOutputInfo() {
inputs.clear();
outputs.clear();
Expand Down Expand Up @@ -860,4 +880,35 @@ bool PaddleParser::ExistsDumplicateTensorName() const {
}
return false;
}

// Defines an explicit specialization of PaddleParser::GetOpScalarsAttr that
// reads the repeated `scalars` attribute named `name` from `op` and converts
// every element to `target_type`. `scalar_type` selects the proto accessor
// for the scalar's stored kind: i = integer, r = real (double), b = boolean.
// Each element is checked with has_<scalar_type>() so a kind mismatch fails
// loudly, and a missing attribute trips the final Assert. As with the other
// GetOpAttr overloads, an attribute held by a Var breaks out early, leaving
// *res empty.
// NOTE(review): `scalars_size() >= 0` is always true (it matches the
// pre-existing GetOpAttr overloads' style) — presumably intended as a
// placeholder check; confirm whether `> 0` was meant.
#define DECLARE_GET_OP_SCALARS(scalar_type, target_type) \
template <> \
void PaddleParser::GetOpScalarsAttr<target_type>(const paddle2onnx::framework::proto::OpDesc& op, \
const std::string& name, \
std::vector<target_type>* res) const { \
bool found = false; \
res->clear(); \
for (auto i = 0; i < op.attrs_size(); ++i) { \
if (op.attrs(i).name() == name) { \
found = true; \
if (IsAttrVar(op, i)) break; \
Assert(op.attrs(i).scalars_size() >= 0, \
"Cannot find list of scalars data from attr: " + name + \
" in op: " + op.type()); \
for (auto j = 0; j < op.attrs(i).scalars_size(); ++j) { \
Assert(op.attrs(i).scalars(j).has_##scalar_type(), \
"Scalar type does not match with " #scalar_type); \
res->push_back(static_cast<target_type>(op.attrs(i).scalars(j).scalar_type())); \
} \
break; \
} \
} \
Assert(found, "Cannot found attribute " + name + " in op: " + op.type()); \
}

// Instantiations: integer scalars (i) feed both int64 and int32 vectors;
// real scalars (r) feed float and double; boolean scalars (b) feed bool.
DECLARE_GET_OP_SCALARS(i, int64_t)
DECLARE_GET_OP_SCALARS(i, int32_t)
DECLARE_GET_OP_SCALARS(r, float)
DECLARE_GET_OP_SCALARS(r, double)
DECLARE_GET_OP_SCALARS(b, bool)
} // namespace paddle2onnx
7 changes: 7 additions & 0 deletions paddle2onnx/parser/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ class PaddleParser {
const std::string& name, std::vector<float>* res) const;
void GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
const std::string& name, std::vector<double>* res) const;
void GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op,
const std::string& name, std::vector<bool>* res) const;

bool IsConstantTensor(const int64_t& block_idx,
const std::string& tensor_name) const;
Expand All @@ -187,6 +189,11 @@ class PaddleParser {
const std::string& tensor_name,
std::vector<T>* data) const;

template <typename T>
void GetOpScalarsAttr(const paddle2onnx::framework::proto::OpDesc& op,
const std::string& name,
std::vector<T>* res) const;

private:
// If the model has same output name in difference operators
// will fail to convert
Expand Down

0 comments on commit c6dcef7

Please sign in to comment.