From fb56fd944aa80bd48f1ef48970eb07286a64fdad Mon Sep 17 00:00:00 2001 From: peterjc123 Date: Fri, 26 Apr 2024 11:04:42 +0800 Subject: [PATCH] [quantizer] fix dynamic shape for conv-bn2d rewrite (#298) * [quantizer] fix dynamic shape for conv-bn2d rewrite * lint fixes --- tinynn/graph/quantization/quantizer.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/tinynn/graph/quantization/quantizer.py b/tinynn/graph/quantization/quantizer.py index 11928739..f0012b14 100644 --- a/tinynn/graph/quantization/quantizer.py +++ b/tinynn/graph/quantization/quantizer.py @@ -2449,28 +2449,20 @@ def _avgpool_kernel_size_and_stride(kernel_size, stride=None, *args, **kwargs): graph.replace_node_module(node_fc, new_conv2d) graph.replace_node_module(node_bn1d, new_bn2d) - prev_tensor_shape = node_fc.prev_tensors[0].shape - prev_func = TraceFunction('torch.reshape', prefix='rewritten_conv2d_bn2d_').parse_args( - node_fc.prev_tensors[0], [prev_tensor_shape[0], prev_tensor_shape[1], 1, 1] + prev_func = TraceFunction('torch.Tensor.__getitem__', prefix='rewritten_conv2d_bn2d_').parse_args( + node_fc.prev_tensors[0], (Ellipsis, None, None) ) - next_tensor_shape = node_bn1d.next_tensors[0].shape - next_func = TraceFunction('torch.reshape', prefix='rewritten_conv2d_bn2d_').parse_args( - node_bn1d.next_tensors[0], [next_tensor_shape[0], next_tensor_shape[1]] + next_func = TraceFunction('torch.flatten', prefix='rewritten_conv2d_bn2d_').parse_args( + node_bn1d.next_tensors[0], 1 ) # expand the tensor shape between fc new_conv2d and new_bn2d node_fc.next_tensors[0].unsqueeze_(2).unsqueeze_(2) node_bn1d.prev_tensors[0].unsqueeze_(2).unsqueeze_(2) node_bn1d.next_tensors[0].unsqueeze_(2).unsqueeze_(2) - prev_out = torch.reshape( - node_fc.prev_tensors[0], - [node_fc.prev_tensors[0].shape[0], node_fc.prev_tensors[0].shape[1], 1, 1], - ) + prev_out = node_fc.prev_tensors[0][..., None, None] graph.insert_between(node_fc.prev_nodes[0], node_fc, prev_func, [prev_out]) - next_out = torch.reshape( - node_bn1d.next_tensors[0], - [node_bn1d.next_tensors[0].shape[0], node_bn1d.prev_tensors[0].shape[1]], - ) + next_out = torch.flatten(node_bn1d.next_tensors[0], 1) graph.insert_after(node_bn1d, next_func, [next_out]) # Rewrite BatchNorm1d to BatchNorm2d