Skip to content

Commit

Permalink
[converter] fix elementwise pass for pack ops
Browse files Browse the repository at this point in the history
  • Loading branch information
peterjc123 committed Sep 20, 2023
1 parent 4058486 commit a2f5fbf
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion tinynn/converter/operators/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1811,10 +1811,14 @@ def elementwise_op_transpose_passthrough_pass(self, quantizable_ops_only: bool =
tensor_node_dict[op_out.name] = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_out.name])

# OP specific dim handling logic
if node['node_type'] in (ExtendedOperator.CONCATENATION, ExtendedOperator.GATHER):
if node['node_type'] in (ExtendedOperator.CONCATENATION, ExtendedOperator.GATHER, ExtendedOperator.UNPACK):
old_axis = op.axis
new_axis = np.where(inv_perm_arr == old_axis)[0][0]
op.axis = new_axis
elif node['node_type'] == ExtendedOperator.PACK:
old_axis = op.axis
new_axis = np.where(inv_perm_arr_post == old_axis)[0][0]
op.axis = new_axis
elif node['node_type'] == ExtendedOperator.SPLIT_V:
old_dim = op.inputs[2].tensor
new_dim = np.where(inv_perm_arr == old_dim)[0][0]
Expand Down

0 comments on commit a2f5fbf

Please sign in to comment.