Skip to content

Commit

Permalink
[modifier] support PixelShuffle()
Browse files Browse the repository at this point in the history
  • Loading branch information
dinghuanghao committed Oct 17, 2023
1 parent 80ea5ce commit 4d7ff7c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
43 changes: 43 additions & 0 deletions tests/modifier_prune_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,49 @@ def init_rnn_by_list(rnn, ch_value):


class ModifierTester(unittest.TestCase):
def test_pixel_shuffle_graph(self):
class TestModel(nn.Module):
def __init__(self):
super().__init__()
self.conv0 = nn.Conv2d(3, 16, (3, 3), padding=(1, 1))
self.pixelshuffle = nn.PixelShuffle(2)
self.conv1 = nn.Conv2d(4, 4, (3, 3), padding=(1, 1))

def forward(self, x):
x = self.conv0(x)
x = self.pixelshuffle(x)
x = self.conv1(x)

return x

def test_func():
model = TestModel()

ch_4 = get_rd_lst(4)
ch_16 = get_rd_lst(16)

init_conv_by_list(model.conv0, ch_16)
init_conv_by_list(model.conv1, ch_4)

importance_conv0 = l2_norm(model.conv0.weight, model.conv0).tolist()
importance_conv0_merge = [sum(importance_conv0[i : i + 4]) for i in range(0, len(importance_conv0), 4)]

conv0_idxes = get_topk(importance_conv0_merge, 2)

prune_idxes = []
for i in conv0_idxes:
prune_idxes += [j for j in range(i * 4, i * 4 + 4)]

pruner = OneShotChannelPruner(model, torch.ones(1, 3, 9, 9), {"sparsity": 0.5, "metrics": "l2_norm"})
pruner.register_mask()

m_conv0 = pruner.graph_modifier.get_modifier(model.conv0)

assert m_conv0.dim_changes_info.pruned_idx_o == prune_idxes

for i in range(20):
test_func()

def test_cat_graph(self):
class TestModel(nn.Module):
def __init__(self):
Expand Down
1 change: 1 addition & 0 deletions tinynn/graph/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2495,6 +2495,7 @@ def modify_output(self, remove_idx):
'mean': ReIndexModifier,
'sum': ReIndexModifier,
'getitem': ReIndexModifier,
nn.PixelShuffle: ReIndexModifier,
nn.RNN: RNNChannelModifier,
nn.GRU: RNNChannelModifier,
nn.LSTM: RNNChannelModifier,
Expand Down

0 comments on commit 4d7ff7c

Please sign in to comment.