Skip to content

Commit

Permalink
[quantizer] fix dynamic shape for conv-bn2d rewrite (#298)
Browse files Browse the repository at this point in the history
* [quantizer] fix dynamic shape for conv-bn2d rewrite

* lint fixes
  • Loading branch information
peterjc123 authored Apr 26, 2024
1 parent 423e6b8 commit fb56fd9
Showing 1 changed file with 6 additions and 14 deletions.
20 changes: 6 additions & 14 deletions tinynn/graph/quantization/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2449,28 +2449,20 @@ def _avgpool_kernel_size_and_stride(kernel_size, stride=None, *args, **kwargs):
graph.replace_node_module(node_fc, new_conv2d)
graph.replace_node_module(node_bn1d, new_bn2d)

prev_tensor_shape = node_fc.prev_tensors[0].shape
prev_func = TraceFunction('torch.reshape', prefix='rewritten_conv2d_bn2d_').parse_args(
node_fc.prev_tensors[0], [prev_tensor_shape[0], prev_tensor_shape[1], 1, 1]
prev_func = TraceFunction('torch.Tensor.__getitem__', prefix='rewritten_conv2d_bn2d_').parse_args(
node_fc.prev_tensors[0], (Ellipsis, None, None)
)
next_tensor_shape = node_bn1d.next_tensors[0].shape
next_func = TraceFunction('torch.reshape', prefix='rewritten_conv2d_bn2d_').parse_args(
node_bn1d.next_tensors[0], [next_tensor_shape[0], next_tensor_shape[1]]
next_func = TraceFunction('torch.flatten', prefix='rewritten_conv2d_bn2d_').parse_args(
node_bn1d.next_tensors[0], 1
)
# expand the tensor shape between fc new_conv2d and new_bn2d
node_fc.next_tensors[0].unsqueeze_(2).unsqueeze_(2)
node_bn1d.prev_tensors[0].unsqueeze_(2).unsqueeze_(2)
node_bn1d.next_tensors[0].unsqueeze_(2).unsqueeze_(2)

prev_out = torch.reshape(
node_fc.prev_tensors[0],
[node_fc.prev_tensors[0].shape[0], node_fc.prev_tensors[0].shape[1], 1, 1],
)
prev_out = node_fc.prev_tensors[0][..., None, None]
graph.insert_between(node_fc.prev_nodes[0], node_fc, prev_func, [prev_out])
next_out = torch.reshape(
node_bn1d.next_tensors[0],
[node_bn1d.next_tensors[0].shape[0], node_bn1d.prev_tensors[0].shape[1]],
)
next_out = torch.flatten(node_bn1d.next_tensors[0], 1)
graph.insert_after(node_bn1d, next_func, [next_out])

# Rewrite BatchNorm1d to BatchNorm2d
Expand Down

0 comments on commit fb56fd9

Please sign in to comment.