Skip to content
This repository has been archived by the owner on Feb 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #566 from aseemw/dev/fix_bug_constant_propagation
Browse files Browse the repository at this point in the history
Bug fix in constant propagation pass
  • Loading branch information
aseemw authored May 1, 2020
2 parents 1aa1feb + b575328 commit 8668144
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions onnx_coreml/_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,7 @@ class ConstantRemover(object):
'''
def __call__(self, graph): # type: (Graph) -> Graph
nodes_to_be_removed = []
graph_outputs = [o[0] for o in graph.outputs]
for node in graph.nodes:
are_all_inputs_constant = True
for input_ in node.inputs:
Expand All @@ -724,6 +725,13 @@ def __call__(self, graph): # type: (Graph) -> Graph
transformation_performed = False
if len(node.parents) != 0 or are_all_inputs_constant == False:
continue
is_graph_out = False
for out_ in node.outputs:
if out_ in graph_outputs:
is_graph_out = True
break
if is_graph_out:
continue
# TODO: Replace If -> ElIf with more general transformation block
if node.op_type == 'Gather':
data = node.input_tensors[node.inputs[0]]
Expand Down

0 comments on commit 8668144

Please sign in to comment.