diff --git a/tests/converter_optimizer_test.py b/tests/converter_optimizer_test.py index d4ac84e1..af6c8027 100644 --- a/tests/converter_optimizer_test.py +++ b/tests/converter_optimizer_test.py @@ -1383,7 +1383,7 @@ def forward(self, x): tflite.BuiltinOperator.CONCATENATION, ) self.assertEqual(tfl_model.Subgraphs(0).Operators(5).OutputsLength(), 1) - + def test_fuse_transposeconv_relu(self): class TestModel(nn.Module): def __init__(self) -> None: @@ -1394,7 +1394,7 @@ def __init__(self) -> None: def forward(self, x): y = self.act(self.transposeconv(x)) - return y + return y model = TestModel() model.eval() @@ -1403,7 +1403,7 @@ def forward(self, x): converter = TFLiteConverter(model, dummy_input, model_path) converter.convert() - + tfl_model = parse_model(model_path) self.assertEqual(tfl_model.OperatorCodesLength(), 1) self.assertEqual(tfl_model.OperatorCodes(0).DeprecatedBuiltinCode(), tflite.BuiltinOperator.TRANSPOSE_CONV) diff --git a/tinynn/converter/base.py b/tinynn/converter/base.py index b664f4a9..27dca9bb 100644 --- a/tinynn/converter/base.py +++ b/tinynn/converter/base.py @@ -58,6 +58,7 @@ def __init__( hybrid_gen_single_op_models: bool = False, hybrid_config: typing.Optional[typing.Dict[str, bool]] = None, group_tensors: bool = False, + missing_outputs_as_constants: bool = False, ) -> None: """ The TFLiteConverter class @@ -110,6 +111,7 @@ def __init__( hybrid_gen_single_op_models: Generate both floating point and quantized version of the model for hybrid \ quantizable ops. Defaults to False group_tensors (bool): Group tensors to save space. Defaults to False + missing_outputs_as_constants (bool): View missing outputs as constants. Defaults to False """ self.model = model @@ -117,7 +119,7 @@ def __init__( self.graph = None self.tensor_map = {} self.tensor_map_copies = {} - self.common_graph = CommonGraph() + self.common_graph = CommonGraph(missing_outputs_as_constants) if type(dummy_input) in (tuple, list): self.dummy_input = dummy_input diff --git a/tinynn/converter/operators/graph.py b/tinynn/converter/operators/graph.py index 9a0efebf..23b790cd 100644 --- a/tinynn/converter/operators/graph.py +++ b/tinynn/converter/operators/graph.py @@ -24,8 +24,9 @@ class CommonGraph(object): input_transpose: typing.List[bool] output_transpose: typing.Union[typing.List[typing.Optional[bool]], typing.Optional[bool]] node_op_counter: int + missing_outputs_as_constants: bool - def __init__(self) -> None: + def __init__(self, missing_outputs_as_constants: bool) -> None: self.graph = ig.Graph(directed=True) self.tensor_map = dict() self.tensor_node_map = dict() @@ -39,6 +40,8 @@ def __init__(self) -> None: self.transform_store = {} self.constant_mapping = {} + self.missing_outputs_as_constants = missing_outputs_as_constants + def add_transform_store(self, tensor_name: str, transform_name: str, new_tensor_name: str): self.transform_store.setdefault(tensor_name, {}) self.transform_store[tensor_name][transform_name] = new_tensor_name @@ -637,17 +640,24 @@ def collect_tensor_buffers( missing_inputs = [name for name, _ in filter(lambda x: x[1] < 0, zip(inputs, input_idx))] missing_outputs = [name for name, _ in filter(lambda x: x[1] < 0, zip(outputs, output_idx))] - assert len(missing_outputs) == 0, f'Some output nodes are missing: {missing_outputs}' - - if len(missing_inputs) != 0: - warnings.warn(f'Some input nodes are missing: {missing_inputs}, will try to add them into graph') - for name in missing_inputs: - tensor = self.tensor_map[name] - tensor.index = tensor_idx - tensor_idx += 1 - tensors.append(tensor) - item_idx = inputs.index(name) - input_idx[item_idx] = tensor.index + if not self.missing_outputs_as_constants: + assert len(missing_outputs) == 0, f'Some output nodes are missing: {missing_outputs}' + + missing_vars_dict = { + 'input': (missing_inputs, inputs, input_idx), + 'output': (missing_outputs, outputs, output_idx), + } + + for key, (missing_vars, var_indices, out_indices) in missing_vars_dict.items(): + if len(missing_vars) != 0: + warnings.warn(f'Some {key} nodes are missing: {missing_vars}, will try to add them into graph') + for name in missing_vars: + tensor = self.tensor_map[name] + tensor.index = tensor_idx + tensor_idx += 1 + tensors.append(tensor) + item_idx = var_indices.index(name) + out_indices[item_idx] = tensor.index return tensors, buffers, input_idx, output_idx