[converter] add aten::broadcast_tensors (#229)
* [converter] add aten::broadcast_tensors

* [misc] revert changes

* [docs] update op matrix
peterjc123 authored Jun 6, 2023
1 parent 0829020 commit 88e49e6
Showing 4 changed files with 72 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/op_matrix.md
@@ -41,6 +41,7 @@ Operators that are implemented in Python
| `aten::bitwise_not` | Only bools are supported in aten::bitwise_not |
| `aten::bitwise_or` | |
| `aten::bmm` | |
| `aten::broadcast_tensors` | |
| `aten::cat` | |
| `aten::chunk` | |
| `aten::clamp` | |
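For context (not part of this commit): `torch.broadcast_tensors` expands its arguments to a single common shape under the usual broadcasting rules. A minimal PyTorch example using the same shapes as the new test below:

import torch

x = torch.randn(1, 1, 64)
y = torch.randn(3, 64)

# Both results are expanded views with the common broadcast shape (1, 3, 64).
a, b = torch.broadcast_tensors(x, y)
print(a.shape, b.shape)  # torch.Size([1, 3, 64]) torch.Size([1, 3, 64])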
17 changes: 17 additions & 0 deletions tests/converter_op_test.py
@@ -1802,6 +1802,23 @@ def model(x, y):
tfl_output = tfl_run_model(model_path, inputs, dummy_output)
assert_close(dummy_output, tfl_output, check_stride=False)

def test_broadcast_tensors(self):
dummy_input = torch.randn(1, 1, 64, dtype=torch.float32)
dummy_input_1 = torch.randn(3, 64, dtype=torch.float32)

def model(x, y):
return torch.broadcast_tensors(x, y)

inputs = [dummy_input, dummy_input_1]

model_path = get_model_path()
converter = TFLiteConverter(model, inputs, model_path, nchw_transpose=False)
converter.convert()

dummy_output = model(*inputs)
tfl_output = tfl_run_model(model_path, inputs, dummy_output)
assert_close(dummy_output, tfl_output, check_stride=False)
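        # Both outputs have the broadcast shape (1, 3, 64). torch.broadcast_tensors
        # returns expanded views (stride 0 along broadcast axes), which is presumably
        # why check_stride=False is passed above: the TFLite result is a dense tensor.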

def test_expand_simple(self):
dummy_input = torch.randn(3, 1, dtype=torch.float32)

1 change: 1 addition & 0 deletions tinynn/converter/operators/torch/__init__.py
@@ -176,6 +176,7 @@
"aten::addbmm": ATenAddbmmOperator,
"aten::baddbmm": ATenBaddbmmOperator,
"aten::linalg_vector_norm": ATenLinalgVectorNormOperator,
"aten::broadcast_tensors": ATenBroadcastTensorsOperator,
# quantized
"aten::quantize_per_tensor": ATenQuantizePerTensorOperator,
"aten::fake_quantize_per_tensor_affine": ATenFakeQuantizePerTensorAffineOperator,
53 changes: 53 additions & 0 deletions tinynn/converter/operators/torch/aten.py
@@ -3937,3 +3937,56 @@ def parse(self, node, attrs, args, graph_converter):

for op in ops:
graph_converter.add_operator(op)


class ATenBroadcastTensorsOperator(ATenBroadcastTensorsSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)

self.run(node)

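        # Expand the TorchScript list argument into the names of its individual tensors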
input_names = graph_converter.get_list_expanded_names(self.input_names[0])
inputs = self.to_tfl_tensors(
input_names, self.input_tensors[0], graph_converter=graph_converter, non_existent_as_buffer=True
)

output_names = [f'{self.output_names[0]}:{i}' for i in range(len(input_names))]
outputs = self.to_tfl_tensors(output_names, self.output_tensors[0])
graph_converter.add_iterable_pair(self.output_names, output_names, 'input')

ops = []
for inp, outp in zip(inputs, outputs):
input_shape = inp.shape
output_shape = outp.shape

            # The input already has the target shape; emit an identity reshape so the output tensor is still produced
if output_shape == input_shape:
inputs = [inp, self.create_attr_tensor(inp.shape)]

ops.append(tfl.ReshapeOperator(inputs, [outp], inp.shape))
continue

new_shape = input_shape
actual_input = inp
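            # Rank-align: prepend singleton dimensions when the output has more dims than the input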
if len(output_shape) > len(input_shape):
new_shape = [1] * (len(output_shape) - len(input_shape)) + list(input_shape)
new_shape_arr = np.array(new_shape, dtype='int32')
new_shape_tensor = self.create_attr_tensor(new_shape_arr)
reshaped = self.create_transform_tensor(np.reshape(inp.tensor, new_shape_arr))
actual_input = reshaped
reshape_op = tfl.ReshapeOperator([inp, new_shape_tensor], [reshaped], new_shape_arr)
reshape_op.extra_hints['direction'] = 'up'
ops.append(reshape_op)

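            # Tile along every axis where the rank-aligned input size (1) differs from the output size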
repeats = []
for x, y in zip(new_shape, output_shape):
if x != y:
repeats.append(y)
else:
repeats.append(1)

repeat_tensor = self.create_attr_tensor(np.array(repeats, dtype='int32'))
ops.append(tfl.TileOperator([actual_input, repeat_tensor], [outp]))

for op in ops:
graph_converter.add_operator(op)
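
A minimal NumPy sketch of the lowering strategy above (illustrative only, not part of the commit; the helper name is invented): rank-align the input with a reshape, then tile along each broadcast axis.

import numpy as np

def broadcast_via_reshape_and_tile(arr, output_shape):
    # Prepend singleton dimensions so the rank matches (the ReshapeOperator step) ...
    new_shape = [1] * (len(output_shape) - arr.ndim) + list(arr.shape)
    reshaped = arr.reshape(new_shape)
    # ... then repeat along every axis where the sizes differ (the TileOperator step).
    repeats = [y if x != y else 1 for x, y in zip(new_shape, output_shape)]
    return np.tile(reshaped, repeats)

x = np.random.randn(3, 64)
out = broadcast_via_reshape_and_tile(x, (1, 3, 64))
assert np.array_equal(out, np.broadcast_to(x, (1, 3, 64)))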
