diff --git a/docs/op_matrix.md b/docs/op_matrix.md
index 014dd87b..72ace2a9 100644
--- a/docs/op_matrix.md
+++ b/docs/op_matrix.md
@@ -41,6 +41,7 @@ Operators that are implemented in Python
 | `aten::bitwise_not` | Only bools are supported in aten::bitwise_not |
 | `aten::bitwise_or` | |
 | `aten::bmm` | |
+| `aten::broadcast_tensors` | |
 | `aten::cat` | |
 | `aten::chunk` | |
 | `aten::clamp` | |
diff --git a/tests/converter_op_test.py b/tests/converter_op_test.py
index e981b70f..856233fc 100644
--- a/tests/converter_op_test.py
+++ b/tests/converter_op_test.py
@@ -1802,6 +1802,23 @@ def model(x, y):
         tfl_output = tfl_run_model(model_path, inputs, dummy_output)
         assert_close(dummy_output, tfl_output, check_stride=False)
 
+    def test_broadcast_tensors(self):
+        dummy_input = torch.randn(1, 1, 64, dtype=torch.float32)
+        dummy_input_1 = torch.randn(3, 64, dtype=torch.float32)
+
+        def model(x, y):
+            return torch.broadcast_tensors(x, y)
+
+        inputs = [dummy_input, dummy_input_1]
+
+        model_path = get_model_path()
+        converter = TFLiteConverter(model, inputs, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(*inputs)
+        tfl_output = tfl_run_model(model_path, inputs, dummy_output)
+        assert_close(dummy_output, tfl_output, check_stride=False)
+
     def test_expand_simple(self):
         dummy_input = torch.randn(3, 1, dtype=torch.float32)
 
diff --git a/tinynn/converter/operators/torch/__init__.py b/tinynn/converter/operators/torch/__init__.py
index 7ffb11f5..cc7d781a 100644
--- a/tinynn/converter/operators/torch/__init__.py
+++ b/tinynn/converter/operators/torch/__init__.py
@@ -176,6 +176,7 @@
     "aten::addbmm": ATenAddbmmOperator,
     "aten::baddbmm": ATenBaddbmmOperator,
     "aten::linalg_vector_norm": ATenLinalgVectorNormOperator,
+    "aten::broadcast_tensors": ATenBroadcastTensorsOperator,
     # quantized
     "aten::quantize_per_tensor": ATenQuantizePerTensorOperator,
     "aten::fake_quantize_per_tensor_affine": ATenFakeQuantizePerTensorAffineOperator,
diff --git a/tinynn/converter/operators/torch/aten.py b/tinynn/converter/operators/torch/aten.py
index 46027d7b..4fcbd90d 100644
--- a/tinynn/converter/operators/torch/aten.py
+++ b/tinynn/converter/operators/torch/aten.py
@@ -3937,3 +3937,56 @@ def parse(self, node, attrs, args, graph_converter):
 
         for op in ops:
             graph_converter.add_operator(op)
+
+
+class ATenBroadcastTensorsOperator(ATenBroadcastTensorsSchema):
+    def parse(self, node, attrs, args, graph_converter):
+        super().parse(node, attrs, args, graph_converter)
+
+        self.run(node)
+
+        input_names = graph_converter.get_list_expanded_names(self.input_names[0])
+        inputs = self.to_tfl_tensors(
+            input_names, self.input_tensors[0], graph_converter=graph_converter, non_existent_as_buffer=True
+        )
+
+        output_names = [f'{self.output_names[0]}:{i}' for i in range(len(input_names))]
+        outputs = self.to_tfl_tensors(output_names, self.output_tensors[0])
+        graph_converter.add_iterable_pair(self.output_names, output_names, 'input')
+
+        ops = []
+        for inp, outp in zip(inputs, outputs):
+            input_shape = inp.shape
+            output_shape = outp.shape
+
+            # No-op if the input tensor already has the desired shape
+            if output_shape == input_shape:
+                inputs = [inp, self.create_attr_tensor(inp.shape)]
+
+                ops.append(tfl.ReshapeOperator(inputs, [outp], inp.shape))
+                continue
+
+            new_shape = input_shape
+            actual_input = inp
+            if len(output_shape) > len(input_shape):
+                new_shape = [1] * (len(output_shape) - len(input_shape)) + list(input_shape)
+                new_shape_arr = np.array(new_shape, dtype='int32')
+                new_shape_tensor = self.create_attr_tensor(new_shape_arr)
+                reshaped = self.create_transform_tensor(np.reshape(inp.tensor, new_shape_arr))
+                actual_input = reshaped
+                reshape_op = tfl.ReshapeOperator([inp, new_shape_tensor], [reshaped], new_shape_arr)
+                reshape_op.extra_hints['direction'] = 'up'
+                ops.append(reshape_op)
+
+            repeats = []
+            for x, y in zip(new_shape, output_shape):
+                if x != y:
+                    repeats.append(y)
+                else:
+                    repeats.append(1)
+
+            repeat_tensor = self.create_attr_tensor(np.array(repeats, dtype='int32'))
+            ops.append(tfl.TileOperator([actual_input, repeat_tensor], [outp]))
+
+        for op in ops:
+            graph_converter.add_operator(op)
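
Note: the patch above lowers each output of aten::broadcast_tensors to a Reshape that left-pads the input shape with 1s up to the broadcast rank, followed by a Tile whose per-axis repeats expand the size-1 axes. A minimal NumPy sketch of that shape logic, reusing the shapes from the new test; the helper broadcast_via_reshape_tile is illustrative only and not part of the converter:

    import numpy as np

    def broadcast_via_reshape_tile(x, out_shape):
        # Left-pad the shape with 1s so the rank matches the broadcast output.
        aligned = x.reshape([1] * (len(out_shape) - x.ndim) + list(x.shape))
        # Tile every axis where the (size-1) input dim differs from the output dim.
        repeats = [o if d != o else 1 for d, o in zip(aligned.shape, out_shape)]
        return np.tile(aligned, repeats)

    a = np.random.randn(1, 1, 64).astype('float32')
    b = np.random.randn(3, 64).astype('float32')
    out_shape = np.broadcast_shapes(a.shape, b.shape)  # (1, 3, 64)

    # The Reshape + Tile pair reproduces NumPy/PyTorch broadcasting exactly.
    ref_a, ref_b = np.broadcast_arrays(a, b)
    assert np.array_equal(broadcast_via_reshape_tile(a, out_shape), ref_a)
    assert np.array_equal(broadcast_via_reshape_tile(b, out_shape), ref_b)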