From b575328ca6b3e78c4f6bf25c62a739ff0945bb93 Mon Sep 17 00:00:00 2001 From: aseemw Date: Thu, 30 Apr 2020 21:40:22 -0700 Subject: [PATCH] terminate constant propagation if it hits a node that is a model output --- onnx_coreml/_transformers.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/onnx_coreml/_transformers.py b/onnx_coreml/_transformers.py index 7492adc..bd954e6 100644 --- a/onnx_coreml/_transformers.py +++ b/onnx_coreml/_transformers.py @@ -714,6 +714,7 @@ class ConstantRemover(object): ''' def __call__(self, graph): # type: (Graph) -> Graph nodes_to_be_removed = [] + graph_outputs = [o[0] for o in graph.outputs] for node in graph.nodes: are_all_inputs_constant = True for input_ in node.inputs: @@ -724,6 +725,13 @@ def __call__(self, graph): # type: (Graph) -> Graph transformation_performed = False if len(node.parents) != 0 or are_all_inputs_constant == False: continue + is_graph_out = False + for out_ in node.outputs: + if out_ in graph_outputs: + is_graph_out = True + break + if is_graph_out: + continue # TODO: Replace If -> ElIf with more general transformation block if node.op_type == 'Gather': data = node.input_tensors[node.inputs[0]]