diff --git a/tinynn/converter/operators/optimize.py b/tinynn/converter/operators/optimize.py index bdea53cf..3bdf17b2 100644 --- a/tinynn/converter/operators/optimize.py +++ b/tinynn/converter/operators/optimize.py @@ -1115,8 +1115,11 @@ def elementwise_reshape_transpose_passthrough_pass(self) -> int: actions = [] remove_edges = [] remove_vertices = [] + processed_nodes = set() num_actions = 0 for node in unique_nodes: + pending_processed_nodes = set() + op = node['op'] input_indices = op_input_indices(op) l_shape = op.inputs[0].shape @@ -1195,6 +1198,7 @@ def elementwise_reshape_transpose_passthrough_pass(self) -> int: prev_output_indices = [] num_constant_nodes = 0 prev_hints = set() + skip = False for i in input_indices: prev_node_name = op.inputs[i].name prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name]) @@ -1202,6 +1206,10 @@ def elementwise_reshape_transpose_passthrough_pass(self) -> int: prev_output_indices.append(prev_node['outputs'].index(prev_node_name)) if prev_node['node_type'] == ExtendedOperator.TRANSPOSE: + if prev_node['name'] in processed_nodes: + skip = True + break + pending_processed_nodes.add(prev_node['name']) if mode == 'down': perm = tuple(prev_node['op'].inputs[1].tensor.tolist()) cand_perms.setdefault(perm, 0) @@ -1216,7 +1224,7 @@ def elementwise_reshape_transpose_passthrough_pass(self) -> int: if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE: num_constant_nodes += 1 - if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'up' in prev_hints: + if skip or (self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'up' in prev_hints): continue next_nodes = [] @@ -1231,6 +1239,10 @@ def elementwise_reshape_transpose_passthrough_pass(self) -> int: if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE: out_nodes.append(next_node) else: + if next_node['name'] in processed_nodes: + skip = True + break + pending_processed_nodes.add(next_node['name']) next_nodes.append(next_node) next_edges.append(edge) @@ -1246,7 +1258,7 @@ def elementwise_reshape_transpose_passthrough_pass(self) -> int: if 'direction' in next_node['op'].extra_hints: next_hints.add(next_node['op'].extra_hints['direction']) - if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'down' in next_hints: + if skip or (self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'down' in next_hints): continue cur_transpose_size = sum(cand_perms.values()) + sum(cand_rev_perms.values()) @@ -1261,6 +1273,9 @@ def elementwise_reshape_transpose_passthrough_pass(self) -> int: if 'down' in prev_hints or 'up' in next_hints: skip = False + if skip: + continue + perm = max(cand_perms.items(), key=lambda x: x[1])[0] perm_arr = np.array(perm, dtype='int32') @@ -1285,6 +1300,9 @@ def elementwise_reshape_transpose_passthrough_pass(self) -> int: remove_edges.extend([x.index for x in next_edges]) remove_vertices.extend([x.index for x in out_nodes]) + for pending_processed_node in pending_processed_nodes: + processed_nodes.add(pending_processed_node) + for n in out_nodes: del self.graph.tensor_map[n['outputs'][0]] del self.graph.tensor_node_map[n['outputs'][0]]