diff --git a/docs/op_matrix.md b/docs/op_matrix.md
index a0de97be..4c435285 100644
--- a/docs/op_matrix.md
+++ b/docs/op_matrix.md
@@ -144,6 +144,7 @@ Operators that are implemented in Python
 | `aten::relu_` | |
 | `aten::remainder` | |
 | `aten::repeat` | |
+| `aten::repeat_interleave` | dynamic repeats_tensor is not supported |
 | `aten::reshape` | |
 | `aten::roll` | |
 | `aten::round` | |
diff --git a/tests/converter_op_test.py b/tests/converter_op_test.py
index 7c457af3..24111b45 100644
--- a/tests/converter_op_test.py
+++ b/tests/converter_op_test.py
@@ -1814,6 +1814,90 @@ def model(x):
         tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
         assert_close(dummy_output, tfl_output, check_stride=False)
 
+    def test_repeat_interleave_single_dim(self):
+        dummy_input = torch.randn(10, dtype=torch.float32)
+
+        def model(x):
+            return torch.repeat_interleave(x, 4)
+
+        model_path = get_model_path()
+        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(dummy_input)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        assert_close(dummy_output, tfl_output)
+
+    def test_repeat_interleave_single_dim_repeats(self):
+        dummy_input = torch.randn(10, dtype=torch.float32)
+
+        def model(x):
+            return torch.repeat_interleave(x, torch.tensor([1, 2, 3, 2, 1, 2, 3, 4, 1, 2]))
+
+        model_path = get_model_path()
+        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(dummy_input)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        assert_close(dummy_output, tfl_output)
+
+    def test_repeat_interleave_multi_dim(self):
+        dummy_input = torch.randn(5, 2, dtype=torch.float32)
+
+        def model(x):
+            return torch.repeat_interleave(x, 4)
+
+        model_path = get_model_path()
+        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(dummy_input)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        assert_close(dummy_output, tfl_output)
+
+    def test_repeat_interleave_multi_dim_repeats(self):
+        dummy_input = torch.randn(5, 2, dtype=torch.float32)
+
+        def model(x):
+            return torch.repeat_interleave(x, torch.tensor([1, 2, 3, 2, 1, 2, 3, 4, 1, 2]))
+
+        model_path = get_model_path()
+        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(dummy_input)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        assert_close(dummy_output, tfl_output)
+
+    def test_repeat_interleave_multi_dim_index(self):
+        dummy_input = torch.randn(5, 2, dtype=torch.float32)
+
+        def model(x):
+            return torch.repeat_interleave(x, torch.tensor([1, 2, 1, 3, 2]), 0)
+
+        model_path = get_model_path()
+        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(dummy_input)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        assert_close(dummy_output, tfl_output)
+
+    def test_repeat_interleave_multi_dim_negative_index(self):
+        dummy_input = torch.randn(5, 2, dtype=torch.float32)
+
+        def model(x):
+            return torch.repeat_interleave(x, torch.tensor([3, 2]), -1)
+
+        model_path = get_model_path()
+        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(dummy_input)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        assert_close(dummy_output, tfl_output)
+
     def test_repeat_single_dim(self):
         dummy_input = torch.randn(10, dtype=torch.float32)
 
diff --git a/tinynn/converter/operators/torch/__init__.py b/tinynn/converter/operators/torch/__init__.py
index c1ff35da..304ada9b 100644
--- a/tinynn/converter/operators/torch/__init__.py
+++ b/tinynn/converter/operators/torch/__init__.py
@@ -183,6 +183,7 @@
     "aten::minimum": ATenMinimumOperator,
     "aten::index_put": ATenIndexPutOperator,
     "aten::index_put_": ATenIndexPutOperator,
+    "aten::repeat_interleave": ATenRepeatInterleaveOperator,
     # quantized
     "aten::quantize_per_tensor": ATenQuantizePerTensorOperator,
     "aten::fake_quantize_per_tensor_affine": ATenFakeQuantizePerTensorAffineOperator,
diff --git a/tinynn/converter/operators/torch/aten.py b/tinynn/converter/operators/torch/aten.py
index 2ada3acd..eb9c6859 100644
--- a/tinynn/converter/operators/torch/aten.py
+++ b/tinynn/converter/operators/torch/aten.py
@@ -2632,6 +2632,67 @@ def parse(self, node, attrs, args, graph_converter):
         graph_converter.add_operator(op)
 
 
+class ATenRepeatInterleaveOperator(ATenRepeatInterleaveSchema):
+    def parse(self, node, attrs, args, graph_converter):
+        super().parse(node, attrs, args, graph_converter)
+
+        self.run(node)
+
+        input_tensor = self.find_or_create_input(0, graph_converter)
+        outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
+
+        if 'dim' in args:
+            dim = self.input_tensors[args['dim']]
+        else:
+            dim = None
+
+        if 'repeats' in args:
+            repeats = self.input_tensors[args['repeats']]
+        else:
+            repeats = None
+
+        if repeats is None:
+            size_repeats = input_tensor.tensor.size
+            raw_indices = torch.arange(size_repeats, dtype=torch.int32)
+            repeats_tensor = input_tensor
+        elif type(repeats) is int:
+            if dim is None:
+                size_repeats = input_tensor.tensor.size
+            else:
+                size_repeats = input_tensor.shape[dim]
+            raw_indices = torch.arange(size_repeats, dtype=torch.int32)
+            repeats_arr = torch.tensor(repeats, dtype=torch.int32)
+            repeats_tensor = self.create_attr_tensor(repeats_arr)
+        else:
+            if dim is None:
+                size_repeats = input_tensor.tensor.size
+            else:
+                size_repeats = input_tensor.shape[dim]
+            raw_indices = torch.arange(size_repeats, dtype=torch.int32)
+            repeats_tensor = self.find_or_create_input(args['repeats'], graph_converter)
+
+        assert repeats_tensor.buffer is not None, "dynamic repeats_tensor is not supported"
+
+        actual_indices = self.create_attr_tensor(
+            torch.repeat_interleave(raw_indices, torch.from_numpy(repeats_tensor.tensor).long())
+        )
+
+        actual_input = input_tensor
+        if dim is None and len(input_tensor.shape) > 1:
+            new_shape = (input_tensor.tensor.size,)
+            shape_tensor = self.create_attr_tensor(np.array(new_shape, dtype='int32'))
+            actual_input = self.create_transform_tensor(np.reshape(input_tensor.tensor, new_shape))
+            graph_converter.add_operator(tfl.ReshapeOperator([input_tensor, shape_tensor], [actual_input], new_shape))
+
+        inputs = [actual_input, actual_indices]
+        gather_dim = dim
+        if gather_dim is None:
+            gather_dim = 0
+        if gather_dim < 0:
+            gather_dim += input_tensor.tensor.ndim
+        graph_converter.add_operator(tfl.GatherOperator(inputs, outputs, gather_dim))
+
+
 class ATenMmOperator(ATenMmSchema):
     def parse(self, node, attrs, args, graph_converter):
         super().parse(node, attrs, args, graph_converter)
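Note on the lowering (reviewer aside, not part of the patch): because `repeats` must be static (the `assert` on `repeats_tensor.buffer`), the converter can precompute the gather indices at conversion time with `torch.repeat_interleave` itself, so the runtime graph is just an optional RESHAPE (flattening when `dim` is None) followed by a single GATHER. A minimal standalone sketch of that equivalence in plain PyTorch; the helper name `repeat_interleave_via_gather` is illustrative only, not converter API:

    import torch

    # Sketch of the lowering: repeat_interleave(x, repeats, dim) is equivalent
    # to gathering along `dim` with indices obtained by repeat-interleaving
    # arange(n) by `repeats`.
    def repeat_interleave_via_gather(x, repeats, dim=None):
        if dim is None:
            # aten::repeat_interleave flattens when no dim is given; this
            # mirrors the ReshapeOperator emitted by the converter.
            x = x.reshape(-1)
            dim = 0
        n = x.shape[dim]
        if isinstance(repeats, int):
            repeats = torch.full((n,), repeats, dtype=torch.long)
        # With static repeats, these indices become a constant attr tensor
        # baked into the converted model at conversion time.
        indices = torch.repeat_interleave(torch.arange(n), repeats)
        # The only op left for the runtime graph: a gather along `dim`.
        return torch.index_select(x, dim, indices)

    x = torch.randn(5, 2)
    assert torch.equal(repeat_interleave_via_gather(x, 4),
                       torch.repeat_interleave(x, 4))
    r = torch.tensor([1, 2, 1, 3, 2])
    assert torch.equal(repeat_interleave_via_gather(x, r, 0),
                       torch.repeat_interleave(x, r, 0))

Dynamic repeats (values that depend on the input at runtime) cannot be constant-folded this way, hence the assertion in the operator and the limitation recorded in docs/op_matrix.md.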