[converter] add aten::repeat_interleave (#288)
* [converter] add aten::repeat_interleave

* fixes

* more fixes
peterjc123 authored Mar 27, 2024
1 parent 092398d commit d8b2a34
Showing 4 changed files with 147 additions and 0 deletions.
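
For reference, torch.repeat_interleave repeats each element of a tensor, either a fixed number of times (scalar repeats) or according to a per-element repeats tensor, optionally along a given dimension. A quick illustration of the PyTorch semantics this commit maps to TFLite (not part of the diff):

import torch

x = torch.tensor([1, 2, 3])
torch.repeat_interleave(x, 2)                        # tensor([1, 1, 2, 2, 3, 3])
torch.repeat_interleave(x, torch.tensor([1, 0, 2]))  # tensor([1, 3, 3])

m = torch.arange(6).reshape(3, 2)
torch.repeat_interleave(m, torch.tensor([1, 2, 1]), dim=0)  # rows 0, 1, 1, 2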
1 change: 1 addition & 0 deletions docs/op_matrix.md
@@ -144,6 +144,7 @@ Operators that are implemented in Python
| `aten::relu_` | |
| `aten::remainder` | |
| `aten::repeat` | |
| `aten::repeat_interleave` | dynamic repeats_tensor is not supported |
| `aten::reshape` | |
| `aten::roll` | |
| `aten::round` | |
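The limitation noted in the new table row means the repeats tensor must be a constant the converter can fold into the graph; a repeats tensor whose values are computed from runtime inputs trips the "dynamic repeats_tensor is not supported" assertion added in tinynn/converter/operators/torch/aten.py below. A minimal sketch of the two cases (both model functions are illustrative, not from the commit):

import torch

def supported(x):
    # Constant repeats: known at conversion time, folded into a buffer.
    return torch.repeat_interleave(x, torch.tensor([1, 2, 1, 3, 2]), 0)

def unsupported(x):
    # Dynamic repeats: the values depend on x at runtime, so the converter
    # has no constant buffer from which to precompute gather indices.
    repeats = x.argmax(dim=1) + 1
    return torch.repeat_interleave(x, repeats, 0)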
84 changes: 84 additions & 0 deletions tests/converter_op_test.py
@@ -1814,6 +1814,90 @@ def model(x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output, check_stride=False)

def test_repeat_interleave_single_dim(self):
dummy_input = torch.randn(10, dtype=torch.float32)

def model(x):
return torch.repeat_interleave(x, 4)

model_path = get_model_path()
converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_repeat_interleave_single_dim_repeats(self):
dummy_input = torch.randn(10, dtype=torch.float32)

def model(x):
return torch.repeat_interleave(x, torch.tensor([1, 2, 3, 2, 1, 2, 3, 4, 1, 2]))

model_path = get_model_path()
converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_repeat_interleave_multi_dim(self):
dummy_input = torch.randn(5, 2, dtype=torch.float32)

def model(x):
return torch.repeat_interleave(x, 4)

model_path = get_model_path()
converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_repeat_interleave_multi_dim_repeats(self):
dummy_input = torch.randn(5, 2, dtype=torch.float32)

def model(x):
return torch.repeat_interleave(x, torch.tensor([1, 2, 3, 2, 1, 2, 3, 4, 1, 2]))

model_path = get_model_path()
converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_repeat_interleave_multi_dim_index(self):
dummy_input = torch.randn(5, 2, dtype=torch.float32)

def model(x):
return torch.repeat_interleave(x, torch.tensor([1, 2, 1, 3, 2]), 0)

model_path = get_model_path()
converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_repeat_interleave_multi_dim_negative_index(self):
dummy_input = torch.randn(5, 2, dtype=torch.float32)

def model(x):
return torch.repeat_interleave(x, torch.tensor([3, 2]), -1)

model_path = get_model_path()
converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
converter.convert()

dummy_output = model(dummy_input)
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_repeat_single_dim(self):
dummy_input = torch.randn(10, dtype=torch.float32)

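Each of the new tests converts the model with TFLiteConverter, writes a .tflite file, and compares the PyTorch output against the TF Lite runtime via the tfl_run_model helper. Independent of the test harness, a converted model can also be run with the stock TensorFlow Lite interpreter; the model path and input shape below are illustrative:

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='repeat_interleave.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

dummy_input = np.random.randn(5, 2).astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], dummy_input)
interpreter.invoke()
tfl_output = interpreter.get_tensor(output_details[0]['index'])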
1 change: 1 addition & 0 deletions tinynn/converter/operators/torch/__init__.py
@@ -183,6 +183,7 @@
"aten::minimum": ATenMinimumOperator,
"aten::index_put": ATenIndexPutOperator,
"aten::index_put_": ATenIndexPutOperator,
"aten::repeat_interleave": ATenRepeatInterleaveOperator,
# quantized
"aten::quantize_per_tensor": ATenQuantizePerTensorOperator,
"aten::fake_quantize_per_tensor_affine": ATenFakeQuantizePerTensorAffineOperator,
61 changes: 61 additions & 0 deletions tinynn/converter/operators/torch/aten.py
@@ -2632,6 +2632,67 @@ def parse(self, node, attrs, args, graph_converter):
graph_converter.add_operator(op)


class ATenRepeatInterleaveOperator(ATenRepeatInterleaveSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)

self.run(node)

input_tensor = self.find_or_create_input(0, graph_converter)
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)

if 'dim' in args:
dim = self.input_tensors[args['dim']]
else:
dim = None

if 'repeats' in args:
repeats = self.input_tensors[args['repeats']]
else:
repeats = None

if repeats is None:
size_repeats = input_tensor.tensor.size
raw_indices = torch.arange(size_repeats, dtype=torch.int32)
repeats_tensor = input_tensor
elif type(repeats) is int:
if dim is None:
size_repeats = input_tensor.tensor.size
else:
size_repeats = input_tensor.shape[dim]
raw_indices = torch.arange(size_repeats, dtype=torch.int32)
repeats_arr = torch.tensor(repeats, dtype=torch.int32)
repeats_tensor = self.create_attr_tensor(repeats_arr)
else:
if dim is None:
size_repeats = input_tensor.tensor.size
else:
size_repeats = input_tensor.shape[dim]
raw_indices = torch.arange(size_repeats, dtype=torch.int32)
repeats_tensor = self.find_or_create_input(args['repeats'], graph_converter)

assert repeats_tensor.buffer is not None, "dynamic repeats_tensor is not supported"

actual_indices = self.create_attr_tensor(
torch.repeat_interleave(raw_indices, torch.from_numpy(repeats_tensor.tensor).long())
)

actual_input = input_tensor
if dim is None and len(input_tensor.shape) > 1:
new_shape = (input_tensor.tensor.size,)
shape_tensor = self.create_attr_tensor(np.array(new_shape, dtype='int32'))
actual_input = self.create_transform_tensor(np.reshape(input_tensor.tensor, new_shape))
graph_converter.add_operator(tfl.ReshapeOperator([input_tensor, shape_tensor], [actual_input], new_shape))

inputs = [actual_input, actual_indices]
gather_dim = dim
if gather_dim is None:
gather_dim = 0
if gather_dim < 0:
gather_dim += input_tensor.tensor.ndim
graph_converter.add_operator(tfl.GatherOperator(inputs, outputs, gather_dim))


class ATenMmOperator(ATenMmSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
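Since TFLite has no direct repeat_interleave operator, the parser above lowers it to GATHER: it precomputes constant indices by repeating arange(size) according to repeats, inserts a RESHAPE to flatten the input when dim is None, and gathers along the normalized dimension. A small PyTorch sketch of the equivalence this relies on (not part of the commit):

import torch

x = torch.randn(5, 2)
repeats = torch.tensor([1, 2, 1, 3, 2])
dim = 0

# Index i appears repeats[i] times, mirroring the constant indices tensor
# the converter feeds to the GATHER op.
indices = torch.repeat_interleave(torch.arange(x.shape[dim]), repeats)
assert torch.equal(
    torch.index_select(x, dim, indices),
    torch.repeat_interleave(x, repeats, dim),
)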
