Skip to content
This repository has been archived by the owner on Feb 7, 2023. It is now read-only.

Commit

Permalink
terminate constant propagation if it hits a node that is a model output
Browse files Browse the repository at this point in the history
  • Loading branch information
aseemw committed May 1, 2020
1 parent 1aa1feb commit b575328
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions onnx_coreml/_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,7 @@ class ConstantRemover(object):
'''
def __call__(self, graph): # type: (Graph) -> Graph
nodes_to_be_removed = []
graph_outputs = [o[0] for o in graph.outputs]
for node in graph.nodes:
are_all_inputs_constant = True
for input_ in node.inputs:
Expand All @@ -724,6 +725,13 @@ def __call__(self, graph): # type: (Graph) -> Graph
transformation_performed = False
if len(node.parents) != 0 or are_all_inputs_constant == False:
continue
is_graph_out = False
for out_ in node.outputs:
if out_ in graph_outputs:
is_graph_out = True
break
if is_graph_out:
continue
# TODO: Replace If -> ElIf with more general transformation block
if node.op_type == 'Gather':
data = node.input_tensors[node.inputs[0]]
Expand Down

0 comments on commit b575328

Please sign in to comment.