Skip to content

Commit

Permalink
[converter] support aten::index for multi dims (#286)
Browse files Browse the repository at this point in the history
  • Loading branch information
peterjc123 authored Mar 26, 2024
1 parent b7a07e4 commit 092398d
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 26 deletions.
2 changes: 1 addition & 1 deletion docs/op_matrix.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ Operators that are implemented in Python
| `aten::hardtanh` | |
| `aten::hardtanh_` | |
| `aten::im2col` | only 4-D input tensors (batched image-like tensors) are supported |
| `aten::index` | Multiple indices for aten::index is not supported |
| `aten::index` | |
| `aten::index_put` | aten::index_put_ with accumulate=True is not supported |
| `aten::index_put_` | aten::index_put_ with accumulate=True is not supported |
| `aten::index_select` | |
Expand Down
45 changes: 45 additions & 0 deletions tests/converter_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4127,6 +4127,51 @@ def model(x):
tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
assert_close(dummy_output, tfl_output)

def test_index_multi_dim_complex(self):
    """Check conversion of ``x[[3, 4], [2, 3]]`` on a (10, 10) float input.

    Every dimension of the input receives an index list, and the converted
    TFLite model must reproduce the PyTorch result.
    """
    sample = torch.randn(10, 10, dtype=torch.float32)

    def model(x):
        return x[[3, 4], [2, 3]]

    tflite_path = get_model_path()

    TFLiteConverter(model, sample, tflite_path, nchw_transpose=False).convert()

    # The PyTorch result is the reference the TFLite output is compared with.
    expected = model(sample)
    actual = tfl_run_model(tflite_path, sample, expected)
    assert_close(expected, actual)

def test_index_multi_dim_complex_1(self):
    """Check conversion of ``x[[3, 4], [2, 3]]`` on a (10, 10, 10) float input.

    Only the first two of the three dimensions receive an index list; the
    converted TFLite model must reproduce the PyTorch result.
    """
    sample = torch.randn(10, 10, 10, dtype=torch.float32)

    def model(x):
        return x[[3, 4], [2, 3]]

    tflite_path = get_model_path()

    TFLiteConverter(model, sample, tflite_path, nchw_transpose=False).convert()

    # The PyTorch result is the reference the TFLite output is compared with.
    expected = model(sample)
    actual = tfl_run_model(tflite_path, sample, expected)
    assert_close(expected, actual)

def test_index_multi_dim_complex_2(self):
    """Check conversion of ``x[[3, 4], [2], [1, 5]]`` on a (10, 10, 10) input.

    All three dimensions receive an index list, and the middle list has a
    single element while the others have two; the converted TFLite model
    must reproduce the PyTorch result.
    """
    sample = torch.randn(10, 10, 10, dtype=torch.float32)

    def model(x):
        return x[[3, 4], [2], [1, 5]]

    tflite_path = get_model_path()

    TFLiteConverter(model, sample, tflite_path, nchw_transpose=False).convert()

    # The PyTorch result is the reference the TFLite output is compared with.
    expected = model(sample)
    actual = tfl_run_model(tflite_path, sample, expected)
    assert_close(expected, actual)

def test_gather(self):
dummy_input = torch.randn(10, dtype=torch.float32)

Expand Down
126 changes: 101 additions & 25 deletions tinynn/converter/operators/torch/aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2437,38 +2437,114 @@ def parse(self, node, attrs, args, graph_converter):

filtered_dims = [i for i, idx in enumerate(indices) if idx is not None]
assert all((indices[i].dtype in (torch.int64, torch.int32) for i in filtered_dims))
assert len(filtered_dims) == 1, "Multiple indices for aten::index is not supported"

try:
names = graph_converter.get_list_expanded_names(self.input_names[1])
except KeyError:
names = [self.get_unique_attr_name() for _ in indices]

filtered_names = [names[i] for i in filtered_dims]
filtered_tensors = [indices[i].to(dtype=torch.int32) for i in filtered_dims]

input_tensor = self.find_or_create_input(0, graph_converter)
# TODO: support negative tensor indices
filtered_tensors = [
t + (t < 0).int() * input_tensor.shape[i] if n not in graph_converter.tensor_map else t
for i, n, t in zip(filtered_dims, filtered_names, filtered_tensors)
]
indice_tensors = self.to_tfl_tensors(
filtered_names, filtered_tensors, graph_converter=graph_converter, non_existent_as_buffer=True
)
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)

actual_input = input_tensor
actual_output = None
for i, (dim, idx) in enumerate(zip(filtered_dims, indice_tensors)):
if i == len(filtered_dims) - 1:
actual_output = outputs[0]
if len(filtered_dims) > 1:
if graph_converter.has_nested_names(self.input_names[1]):
input_names = graph_converter.get_list_expanded_names(self.input_names[1])
indices_tensors = self.to_tfl_tensors(
input_names, self.input_tensors[1], graph_converter=graph_converter, non_existent_as_buffer=True
)
else:
if type(self.input_tensors[1]) in (tuple, list):
indices_tensors = [self.create_attr_tensor(x) for x in self.input_tensors[1]]
else:
indices_tensors = [self.find_or_create_input(1, graph_converter)]

dim = input_tensor.tensor.ndim

indices_shape = [x.tensor.size for x in indices_tensors]
max_len = max(indices_shape)
indices_shape_tensor = torch.tensor(indices_shape)
left_indices = (
torch.arange(max_len).view(-1, 1).expand(-1, len(indices_shape)) % indices_shape_tensor
).int()
all_indices_shape = list(outputs[0].shape) + [dim]

if len(indices_tensors) < dim:
pad_shape = list(input_tensor.shape[len(indices_tensors) :])
pad_indices = torch.ones(pad_shape).nonzero().int()
left_len = len(indices_shape)
right_len = len(pad_shape)
left_size = left_indices.size(0)
right_size = pad_indices.size(0)
left_reshaped = (
left_indices.view(-1, 1, left_len).expand(-1, right_size, left_len).reshape(-1, left_len)
)
right_reshaped = (
pad_indices.view(1, -1, right_len).expand(left_size, -1, right_len).reshape(-1, right_len)
)
all_indices = torch.cat([left_reshaped, right_reshaped], 1).view(all_indices_shape).unbind(-1)
else:
all_indices = left_indices.view(all_indices_shape).unbind(-1)

new_indices = []
for i in range(dim):
if i < len(indices_tensors):
idx_tensor = indices_tensors[i]
actual_idx = np.take(idx_tensor.tensor, all_indices[i].numpy())
else:
actual_idx = all_indices[i].numpy()
if idx_tensor.buffer is None and i < len(indices_tensors):
actual_idx_t = self.create_transform_tensor(actual_idx)
fake_idx_t = self.create_attr_tensor(all_indices[i].numpy())
graph_converter.add_operator(tfl.GatherOperator([idx_tensor, fake_idx_t], [actual_idx_t], axis=0))

if str(actual_idx_t.dtype) != 'int32':
index_casted = self.create_transform_tensor(actual_idx_t.tensor.astype('int32'))
graph_converter.add_operator(
tfl.CastOperator(
[actual_idx_t],
[index_casted],
tfl.numpy_tflite_dtype_mappings[str(actual_idx_t.dtype)],
tfl.numpy_tflite_dtype_mappings[str(index_casted.dtype)],
)
)
actual_idx_t = index_casted
new_indices.append(actual_idx_t)
else:
new_indices.append(self.create_attr_tensor(actual_idx.astype(np.int32)))

index_arr = np.stack([x.tensor for x in new_indices], -1)
if all((x.buffer is not None for x in new_indices)):
index_tensor = self.create_attr_tensor(index_arr)
else:
actual_output = self.create_transform_tensor(np.take(actual_input.tensor, idx.tensor, axis=dim))
index_tensor = self.create_transform_tensor(index_arr)
graph_converter.add_operator(
tfl.PackOperator(new_indices, [index_tensor], dim, axis=index_tensor.tensor.ndim - 1)
)

graph_converter.add_operator(tfl.GatherNdOperator([input_tensor, index_tensor], outputs))
else:
try:
names = graph_converter.get_list_expanded_names(self.input_names[1])
except KeyError:
names = [self.get_unique_attr_name() for _ in indices]

filtered_names = [names[i] for i in filtered_dims]
filtered_tensors = [indices[i].to(dtype=torch.int32) for i in filtered_dims]

filtered_tensors = [
t + (t < 0).int() * input_tensor.shape[i] if n not in graph_converter.tensor_map else t
for i, n, t in zip(filtered_dims, filtered_names, filtered_tensors)
]
indice_tensors = self.to_tfl_tensors(
filtered_names, filtered_tensors, graph_converter=graph_converter, non_existent_as_buffer=True
)

actual_input = input_tensor
actual_output = None
for i, (dim, idx) in enumerate(zip(filtered_dims, indice_tensors)):
if i == len(filtered_dims) - 1:
actual_output = outputs[0]
else:
actual_output = self.create_transform_tensor(np.take(actual_input.tensor, idx.tensor, axis=dim))

graph_converter.add_operator(tfl.GatherOperator([actual_input, idx], [actual_output], axis=dim))
graph_converter.add_operator(tfl.GatherOperator([actual_input, idx], [actual_output], axis=dim))

actual_input = actual_output
actual_input = actual_output


class ATenIndexSelectOperator(ATenIndexSelectSchema):
Expand Down

0 comments on commit 092398d

Please sign in to comment.