diff --git a/paddle2onnx/mapper/mapper.h b/paddle2onnx/mapper/mapper.h index dd551fc98..58281286a 100755 --- a/paddle2onnx/mapper/mapper.h +++ b/paddle2onnx/mapper/mapper.h @@ -183,33 +183,17 @@ class Mapper { auto &op = parser_->GetOpDesc(block_idx_, op_idx_); return parser_->OpHasAttr(op, name); } - void GetAttr(const std::string &name, int64_t *val) { - auto &op = parser_->GetOpDesc(block_idx_, op_idx_); - parser_->GetOpAttr(op, name, val); - } - void GetAttr(const std::string &name, float *val) { - auto &op = parser_->GetOpDesc(block_idx_, op_idx_); - parser_->GetOpAttr(op, name, val); - } - void GetAttr(const std::string &name, bool *val) { - auto &op = parser_->GetOpDesc(block_idx_, op_idx_); - parser_->GetOpAttr(op, name, val); - } - void GetAttr(const std::string &name, std::string *val) { - auto &op = parser_->GetOpDesc(block_idx_, op_idx_); - parser_->GetOpAttr(op, name, val); - } - void GetAttr(const std::string &name, std::vector *val) { - auto &op = parser_->GetOpDesc(block_idx_, op_idx_); - parser_->GetOpAttr(op, name, val); - } - void GetAttr(const std::string &name, std::vector *val) { + + template + void GetAttr(const std::string &name, T* val) { auto &op = parser_->GetOpDesc(block_idx_, op_idx_); parser_->GetOpAttr(op, name, val); } - void GetAttr(const std::string &name, std::vector *val) { + + template + void GetScalars(const std::string &name, std::vector* val){ auto &op = parser_->GetOpDesc(block_idx_, op_idx_); - parser_->GetOpAttr(op, name, val); + parser_->GetOpScalarsAttr(op, name, val); } bool IsConstantInput(const std::string &input_key) const { diff --git a/paddle2onnx/mapper/tensor/assign_value.h b/paddle2onnx/mapper/tensor/assign_value.h index 3238612fb..64cf85384 100644 --- a/paddle2onnx/mapper/tensor/assign_value.h +++ b/paddle2onnx/mapper/tensor/assign_value.h @@ -15,8 +15,9 @@ #pragma once #include #include - #include "paddle2onnx/mapper/mapper.h" +#include +#include namespace paddle2onnx { @@ -27,21 +28,67 @@ class AssignValueMapper : public Mapper { : Mapper(p, helper, block_id, op_id) { GetAttr("dtype", &dtype_); GetAttr("shape", &shape_); + GetAttrValues(); + } + int32_t GetMinOpsetVersion(bool verbose) override; + void Opset7() override; + + private: + void GetAttrValues(){ int32_t dtype = static_cast(dtype_); + const std::string attr_name = HasAttr("values") ? "values" : GetAttrNameByDtype(dtype); + std::unordered_map> type_handlers = { + {P2ODataType::INT32, [&](){ + if (attr_name == "values") GetScalars(attr_name, &int64_values_); + else if (attr_name == "int32_values") GetAttr(attr_name, &int64_values_); + }}, + {P2ODataType::INT64, [&](){ + if (attr_name == "values") GetScalars(attr_name, &int64_values_); + else if (attr_name == "int64_values") GetAttr(attr_name, &int64_values_); + }}, + {P2ODataType::FP32, [&](){ + if (attr_name == "values") GetScalars(attr_name, &fp32_values_); + else if (attr_name == "fp32_values") GetAttr(attr_name, &fp32_values_); + }}, + {P2ODataType::FP64, [&](){ + if (attr_name == "values") GetScalars(attr_name, &double_values_); + else if (attr_name == "fp32_values") GetAttr(attr_name, &double_values_); + }}, + {P2ODataType::BOOL, [&](){ + if (attr_name == "values") GetScalars(attr_name, &bool_values_); + else if (attr_name == "bool_values") GetAttr(attr_name, &bool_values_); + }}, + }; + + auto handler = type_handlers.find(dtype); + if (handler != type_handlers.end()) { + handler->second(); + } else { + Error() << "Unsupported dtype value" << std::endl; + } + } + + std::string GetAttrNameByDtype(int32_t dtype) { if (dtype == P2ODataType::INT32) { - GetAttr("int32_values", &int64_values_); - } else if (dtype == P2ODataType::FP32) { - GetAttr("fp32_values", &fp32_values_); + return "int32_values"; } else if (dtype == P2ODataType::INT64) { - GetAttr("int64_values", &int64_values_); + return "int64_values"; + }else if (dtype == P2ODataType::FP32) { + return "fp32_values"; + } else if (dtype == P2ODataType::FP64) { + return "double_values"; + } else if (dtype == P2ODataType::BOOL) { + return "bool_values"; } + Error() << "Unsupported dtype value" << std::endl; + } - int32_t GetMinOpsetVersion(bool verbose) override; - void Opset7() override; - private: std::vector fp32_values_; std::vector int64_values_; + std::vector bool_values_; + std::vector double_values_; + std::vector int32_values_; std::vector shape_; int64_t dtype_; }; diff --git a/paddle2onnx/parser/parser.cc b/paddle2onnx/parser/parser.cc index c2c03019c..e8c6c0d8a 100755 --- a/paddle2onnx/parser/parser.cc +++ b/paddle2onnx/parser/parser.cc @@ -739,6 +739,7 @@ void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op, Assert(found, "Cannot found attribute " + name + " in op: " + op.type()); } + void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op, const std::string& name, std::vector* res) const { @@ -759,7 +760,26 @@ void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op, } Assert(found, "Cannot found attribute " + name + " in op: " + op.type()); } - +void PaddleParser::GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op, + const std::string& name, + std::vector* res) const { + bool found = false; + res->clear(); + for (auto i = 0; i < op.attrs_size(); ++i) { + if (op.attrs(i).name() == name) { + found = true; + if (IsAttrVar(op, i)) break; + Assert(op.attrs(i).bools_size() >= 0, + "Cannot find list of double data from attr: " + name + " in op: " + + op.type()); + for (auto j = 0; j < op.attrs(i).bools_size(); ++j) { + res->push_back(static_cast(op.attrs(i).bools(j))); + } + break; + } + } + Assert(found, "Cannot found attribute " + name + " in op: " + op.type()); +} void PaddleParser::GetGlobalBlockInputOutputInfo() { inputs.clear(); outputs.clear(); @@ -860,4 +880,35 @@ bool PaddleParser::ExistsDumplicateTensorName() const { } return false; } + +#define DECLARE_GET_OP_SCALARS(scalar_type, target_type) \ +template <> \ +void PaddleParser::GetOpScalarsAttr(const paddle2onnx::framework::proto::OpDesc& op, \ + const std::string& name, \ + std::vector* res) const { \ + bool found = false; \ + res->clear(); \ + for (auto i = 0; i < op.attrs_size(); ++i) { \ + if (op.attrs(i).name() == name) { \ + found = true; \ + if (IsAttrVar(op, i)) break; \ + Assert(op.attrs(i).scalars_size() >= 0, \ + "Cannot find list of scalars data from attr: " + name + \ + " in op: " + op.type()); \ + for (auto j = 0; j < op.attrs(i).scalars_size(); ++j) { \ + Assert(op.attrs(i).scalars(j).has_##scalar_type(), \ + "Scalar type does not match with " #scalar_type); \ + res->push_back(static_cast(op.attrs(i).scalars(j).scalar_type())); \ + } \ + break; \ + } \ + } \ + Assert(found, "Cannot found attribute " + name + " in op: " + op.type()); \ +} + +DECLARE_GET_OP_SCALARS(i, int64_t) +DECLARE_GET_OP_SCALARS(i, int32_t) +DECLARE_GET_OP_SCALARS(r, float) +DECLARE_GET_OP_SCALARS(r, double) +DECLARE_GET_OP_SCALARS(b, bool) } // namespace paddle2onnx diff --git a/paddle2onnx/parser/parser.h b/paddle2onnx/parser/parser.h index 0d15a72bc..3812bea0e 100644 --- a/paddle2onnx/parser/parser.h +++ b/paddle2onnx/parser/parser.h @@ -179,6 +179,8 @@ class PaddleParser { const std::string& name, std::vector* res) const; void GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op, const std::string& name, std::vector* res) const; + void GetOpAttr(const paddle2onnx::framework::proto::OpDesc& op, + const std::string& name, std::vector* res) const; bool IsConstantTensor(const int64_t& block_idx, const std::string& tensor_name) const; @@ -187,6 +189,11 @@ class PaddleParser { const std::string& tensor_name, std::vector* data) const; + template + void GetOpScalarsAttr(const paddle2onnx::framework::proto::OpDesc& op, + const std::string& name, + std::vector* res) const; + private: // If the model has same output name in difference operators // will fail to convert