Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
risemeup1 committed Aug 15, 2024
1 parent 93918ca commit eb89650
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
4 changes: 2 additions & 2 deletions paddle2onnx/mapper/mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,9 @@ class Mapper {
}

template<typename T>
void GetScalarsAttr(const std::string &name, std::vector<T>* val){
void GetScalars(const std::string &name, std::vector<T>* val){
auto &op = parser_->GetOpDesc(block_idx_, op_idx_);
parser_->GetOpScalars(op, name, val);
parser_->GetOpScalarsAttr(op, name, val);
}

bool IsConstantInput(const std::string &input_key) const {
Expand Down
10 changes: 5 additions & 5 deletions paddle2onnx/mapper/tensor/assign_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,23 @@ class AssignValueMapper : public Mapper {
const std::string attr_name = HasAttr("values") ? "values" : GetAttrNameByDtype(dtype);
std::unordered_map<int32_t, std::function<void()>> type_handlers = {
{P2ODataType::INT32, [&](){
if (attr_name == "values") GetScalarsAttr(attr_name, &int64_values_);
if (attr_name == "values") GetScalars(attr_name, &int64_values_);
else if (attr_name == "int32_values") GetAttr(attr_name, &int64_values_);
}},
{P2ODataType::INT64, [&](){
if (attr_name == "values") GetScalarsAttr(attr_name, &int64_values_);
if (attr_name == "values") GetScalars(attr_name, &int64_values_);
else if (attr_name == "int64_values") GetAttr(attr_name, &int64_values_);
}},
{P2ODataType::FP32, [&](){
if (attr_name == "values") GetScalarsAttr(attr_name, &fp32_values_);
if (attr_name == "values") GetScalars(attr_name, &fp32_values_);
else if (attr_name == "fp32_values") GetAttr(attr_name, &fp32_values_);
}},
{P2ODataType::FP64, [&](){
if (attr_name == "values") GetScalarsAttr(attr_name, &double_values_);
if (attr_name == "values") GetScalars(attr_name, &double_values_);
else if (attr_name == "fp32_values") GetAttr(attr_name, &double_values_);
}},
{P2ODataType::BOOL, [&](){
if (attr_name == "values") GetScalarsAttr(attr_name, &bool_values_);
if (attr_name == "values") GetScalars(attr_name, &bool_values_);
else if (attr_name == "bool_values") GetAttr(attr_name, &bool_values_);
}},
};
Expand Down
2 changes: 1 addition & 1 deletion paddle2onnx/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ bool PaddleParser::ExistsDumplicateTensorName() const {

#define DECLARE_GET_OP_SCALARS(scalar_type, target_type) \
template <> \
void PaddleParser::GetOpScalars<target_type>(const paddle2onnx::framework::proto::OpDesc& op, \
void PaddleParser::GetOpScalarsAttr<target_type>(const paddle2onnx::framework::proto::OpDesc& op, \
const std::string& name, \
std::vector<target_type>* res) const { \
bool found = false; \
Expand Down
2 changes: 1 addition & 1 deletion paddle2onnx/parser/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class PaddleParser {
std::vector<T>* data) const;

template <typename T>
void GetOpScalars(const paddle2onnx::framework::proto::OpDesc& op,
void GetOpScalarsAttr(const paddle2onnx::framework::proto::OpDesc& op,
const std::string& name,
std::vector<T>* res) const;

Expand Down

0 comments on commit eb89650

Please sign in to comment.