Skip to content

Commit

Permalink
【Hackathon 7th No.46】 添加对返回常量的 IfElse 算子的支持 (#1383)
Browse files Browse the repository at this point in the history
* wip

* fix

* update due to comment

* Add missing implementation

* Restore code format

* Restore code format
  • Loading branch information
Asthestarsfalll authored Sep 20, 2024
1 parent 48fa34e commit 6b4bd2f
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 12 deletions.
61 changes: 52 additions & 9 deletions paddle2onnx/mapper/exporter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,24 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportConditionalBlock(
temp_inputs, temp_outputs));
}

ONNX_NAMESPACE::GraphProto ModelExporter::ExportFillConstant(
    const PaddleParser &parser, OnnxHelper *temp_helper, int32_t block_id,
    int32_t op_id, const std::string &output_names) {
  // Builds a single-output sub-graph wrapping the fill_constant op located at
  // (block_id, op_id), suitable as a then/else branch of an ONNX If node.
  //
  // parser       : source Paddle program (used to look up the op's output).
  // temp_helper  : holds the ONNX nodes already emitted for this block; the
  //                constant node is copied out of it by output name.
  // block_id/op_id: coordinates of the fill_constant op in the Paddle program.
  // output_names : name of the constant's output tensor (single name despite
  //                the plural identifier, kept for interface compatibility).
  ONNX_NAMESPACE::GraphProto graph;
  graph.set_name("PaddlePaddle fill_constant Graph " + std::to_string(op_id));
  auto out_info = parser.GetOpOutput(block_id, op_id, "Out");

  // The sub-graph exposes the constant's tensor as its only output.
  *graph.add_output() = *MakeValueInfo(out_info[0]);

  // Copy the already-exported constant node (matched by its first output
  // name) from the parent block's node list into the sub-graph.
  for (const auto &node : temp_helper->nodes) {
    if (node->output(0) == output_names) {
      *graph.add_node() = *node;
      break;
    }
  }

  // Plain return enables NRVO; `return std::move(graph)` would inhibit copy
  // elision (pessimizing move).
  return graph;
}
ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
const PaddleParser &parser, int32_t block_id,
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>> &parameters,
Expand Down Expand Up @@ -328,23 +346,45 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
Assert(input_info.size() == 2,
"Only support when number of select_input's input_node is 2.");

// Build else sub graph
auto else_node_name = input_info[0].name;
auto conditional_block_cood_it = sub_block_map_.find(else_node_name);
Assert(conditional_block_cood_it != sub_block_map_.end(),
"Don't find select_input else_input node.");
"Can't find select_input else_input node.");
auto conditional_block_cood = conditional_block_cood_it->second;
auto else_graph =
ExportConditionalBlock(parser, conditional_block_cood.first,
conditional_block_cood.second, else_node_name);
ONNX_NAMESPACE::GraphProto else_graph, then_graph;
auto else_node = parser.GetOpDesc(conditional_block_cood.first,
conditional_block_cood.second);

if (else_node.type().find("conditional_block") != std::string::npos) {
else_graph = ExportConditionalBlock(
parser, conditional_block_cood.first, conditional_block_cood.second,
else_node_name);
} else {
else_graph = ExportFillConstant(
parser, &temp_helper, conditional_block_cood.first,
conditional_block_cood.second, else_node_name);
}

// Build then sub graph
auto then_node_name = input_info[1].name;
conditional_block_cood_it = sub_block_map_.find(then_node_name);
Assert(conditional_block_cood_it != sub_block_map_.end(),
"Don't find select_input then_input node.");
"Can't find select_input then_input node.");
conditional_block_cood = conditional_block_cood_it->second;
auto then_graph =
ExportConditionalBlock(parser, conditional_block_cood.first,
conditional_block_cood.second, then_node_name);
auto then_node = parser.GetOpDesc(conditional_block_cood.first,
conditional_block_cood.second);

// use node.type() to make sure correctness
if (then_node.type().find("conditional_block") != std::string::npos) {
then_graph = ExportConditionalBlock(
parser, conditional_block_cood.first, conditional_block_cood.second,
then_node_name);
} else {
then_graph = ExportFillConstant(
parser, &temp_helper, conditional_block_cood.first,
conditional_block_cood.second, then_node_name);
}

auto cond_info = parser.GetOpInput(block_id, op_id, "Mask");
auto output_info = parser.GetOpOutput(block_id, op_id, "Out");
Expand All @@ -355,6 +395,9 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
AddAttribute(node, "then_branch", then_graph);
AddAttribute(node, "else_branch", else_graph);
continue;
} else if (op.type() == "fill_constant") {
auto out_info = parser.GetOpOutput(block_id, op_id, "Out");
sub_block_map_[out_info[0].name] = {block_id, op_id};
}
ExportOp(parser, &temp_helper, opset_version_, block_id, op_id, verbose_);
}
Expand Down Expand Up @@ -784,4 +827,4 @@ ONNX_NAMESPACE::ModelProto ModelExporter::Optimize(
return ONNX_NAMESPACE::optimization::Optimize(model, passes);
}

} // namespace paddle2onnx
} // namespace paddle2onnx
4 changes: 4 additions & 0 deletions paddle2onnx/mapper/exporter.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ class ModelExporter {
ONNX_NAMESPACE::GraphProto ExportConditionalBlock(
const PaddleParser &parser, int32_t block_id, int32_t op_id,
const std::string &output_names);
ONNX_NAMESPACE::GraphProto ExportFillConstant(
const PaddleParser &parser, OnnxHelper *temp_helper,
int32_t block_id, int32_t op_id,
const std::string &output_names);
ONNX_NAMESPACE::GraphProto ExportBlock(
const PaddleParser &parser, int32_t block_id,
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>> &parameters,
Expand Down
6 changes: 5 additions & 1 deletion tests/onnxbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ def compare(result, expect, delta=1e-10, rtol=1e-10):
# Convert Paddle Tensor to Numpy array
if type(expect) == list:
expect = expect[0]
expect = expect.numpy()

if isinstance(expect, paddle.Tensor):
expect = expect.numpy()
else:
expect = np.array(expect)

# For result_shape is (1) and expect_shape shape is ()
expect = expect.squeeze()
Expand Down
82 changes: 80 additions & 2 deletions tests/test_ifelse.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self):

def forward(self, cond, inputs):
if cond == 1:
return inputs * 1, inputs * 2
return inputs * 1, inputs * 2
else:
return inputs * 3, inputs * 4

Expand All @@ -64,8 +64,86 @@ def test_ifelse_2_false():
obj.set_input_data("input_data", paddle.to_tensor(2), paddle.to_tensor(1))
obj.run()

class BaseNet3(paddle.nn.Layer):
    # Minimal net whose branches both return bare Python constants; this
    # exercises the exporter path where an If branch is a fill_constant op
    # rather than a conditional_block.
    def __init__(self):
        super(BaseNet3, self).__init__()

    def forward(self, inputs):
        # NOTE(review): the literal if/else statement shape matters here —
        # paddle's dynamic-to-static transform turns it into the
        # select_input/fill_constant ops this test targets, so it must not
        # be rewritten as a ternary expression.
        if inputs == 1:
            return 1
        else:
            return 2

def test_ifelse_3_true():
    """IfElse with constant-only branches: the true path (input == 1)."""
    net = BaseNet3()
    net.eval()
    checker = APIOnnx(net, 'ifelse', [11])
    checker.set_input_data("input_data", paddle.to_tensor(1))
    checker.run()

def test_ifelse_3_false():
    """IfElse with constant-only branches: the false path (input == 2)."""
    net = BaseNet3()
    net.eval()
    checker = APIOnnx(net, 'ifelse', [11])
    checker.set_input_data("input_data", paddle.to_tensor(2))
    checker.run()

class BaseNet4(paddle.nn.Layer):
    # One branch returns a tensor expression, the other a bare constant;
    # mixes a conditional_block branch with a fill_constant branch in the
    # exported If node.
    def __init__(self):
        super(BaseNet4, self).__init__()

    def forward(self, inputs):
        # NOTE(review): keep the statement-form if/else — paddle's dy2st
        # conversion of exactly this shape is what the exporter test covers.
        if inputs == 1:
            return inputs + 1
        else:
            return 2

def test_ifelse_4_true():
    """IfElse mixing tensor and constant branches: the true path."""
    net = BaseNet4()
    net.eval()
    checker = APIOnnx(net, 'ifelse', [11])
    checker.set_input_data("input_data", paddle.to_tensor(1))
    checker.run()

def test_ifelse_4_false():
    """IfElse mixing tensor and constant branches: the false path."""
    net = BaseNet4()
    net.eval()
    checker = APIOnnx(net, 'ifelse', [11])
    checker.set_input_data("input_data", paddle.to_tensor(2))
    checker.run()

class BaseNet5(paddle.nn.Layer):
    # Both branches return tuples of bare constants; exercises multi-output
    # constant branches of the exported If node.
    def __init__(self):
        super(BaseNet5, self).__init__()

    def forward(self, inputs):
        # NOTE(review): statement-form if/else is intentional — it is the
        # construct paddle converts into the IfElse ops under test.
        if inputs == 1:
            return 1, 2
        else:
            return 2, 3

def test_ifelse_5_true():
    """IfElse returning tuples of constants: the true path."""
    net = BaseNet5()
    net.eval()
    checker = APIOnnx(net, 'ifelse', [11])
    checker.set_input_data("input_data", paddle.to_tensor(1))
    checker.run()

def test_ifelse_5_false():
    """IfElse returning tuples of constants: the false path."""
    net = BaseNet5()
    net.eval()
    checker = APIOnnx(net, 'ifelse', [11])
    checker.set_input_data("input_data", paddle.to_tensor(2))
    checker.run()

if __name__ == "__main__":
    # Run every scenario (true/false path of each net) when executed directly.
    # Fixed: test_ifelse_2_false() was previously invoked twice in a row.
    test_ifelse_1_true()
    test_ifelse_1_false()
    test_ifelse_2_true()
    test_ifelse_2_false()
    test_ifelse_3_true()
    test_ifelse_3_false()
    test_ifelse_4_true()
    test_ifelse_4_false()
    test_ifelse_5_true()
    test_ifelse_5_false()

0 comments on commit 6b4bd2f

Please sign in to comment.