From 4d7ff7c10d113afe60e4ea651551387121a5dbd0 Mon Sep 17 00:00:00 2001 From: dinghuanghao Date: Tue, 17 Oct 2023 17:21:24 +0800 Subject: [PATCH] [modifier] support PixelShuffle() --- tests/modifier_prune_test.py | 43 ++++++++++++++++++++++++++++++++++++ tinynn/graph/modifier.py | 1 + 2 files changed, 44 insertions(+) diff --git a/tests/modifier_prune_test.py b/tests/modifier_prune_test.py index 8fb5a81c..cfd73af7 100644 --- a/tests/modifier_prune_test.py +++ b/tests/modifier_prune_test.py @@ -83,6 +83,49 @@ def init_rnn_by_list(rnn, ch_value): class ModifierTester(unittest.TestCase): + def test_pixel_shuffle_graph(self): + class TestModel(nn.Module): + def __init__(self): + super().__init__() + self.conv0 = nn.Conv2d(3, 16, (3, 3), padding=(1, 1)) + self.pixelshuffle = nn.PixelShuffle(2) + self.conv1 = nn.Conv2d(4, 4, (3, 3), padding=(1, 1)) + + def forward(self, x): + x = self.conv0(x) + x = self.pixelshuffle(x) + x = self.conv1(x) + + return x + + def test_func(): + model = TestModel() + + ch_4 = get_rd_lst(4) + ch_16 = get_rd_lst(16) + + init_conv_by_list(model.conv0, ch_16) + init_conv_by_list(model.conv1, ch_4) + + importance_conv0 = l2_norm(model.conv0.weight, model.conv0).tolist() + importance_conv0_merge = [sum(importance_conv0[i : i + 4]) for i in range(0, len(importance_conv0), 4)] + + conv0_idxes = get_topk(importance_conv0_merge, 2) + + prune_idxes = [] + for i in conv0_idxes: + prune_idxes += [j for j in range(i * 4, i * 4 + 4)] + + pruner = OneShotChannelPruner(model, torch.ones(1, 3, 9, 9), {"sparsity": 0.5, "metrics": "l2_norm"}) + pruner.register_mask() + + m_conv0 = pruner.graph_modifier.get_modifier(model.conv0) + + assert m_conv0.dim_changes_info.pruned_idx_o == prune_idxes + + for i in range(20): + test_func() + def test_cat_graph(self): class TestModel(nn.Module): def __init__(self): diff --git a/tinynn/graph/modifier.py b/tinynn/graph/modifier.py index 4078de02..07b95f17 100644 --- a/tinynn/graph/modifier.py +++ b/tinynn/graph/modifier.py @@ -2495,6 +2495,7 @@ def modify_output(self, remove_idx): 'mean': ReIndexModifier, 'sum': ReIndexModifier, 'getitem': ReIndexModifier, + nn.PixelShuffle: ReIndexModifier, nn.RNN: RNNChannelModifier, nn.GRU: RNNChannelModifier, nn.LSTM: RNNChannelModifier,