Skip to content

Commit

Permalink
[converter] add missing_outputs_as_constants (#290)
Browse files Browse the repository at this point in the history
* [converter] add missing_outputs_as_constants

* minor fixes

* lint fixes
  • Loading branch information
peterjc123 authored Apr 8, 2024
1 parent 1fd562e commit 6b62859
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 16 deletions.
6 changes: 3 additions & 3 deletions tests/converter_optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1383,7 +1383,7 @@ def forward(self, x):
tflite.BuiltinOperator.CONCATENATION,
)
self.assertEqual(tfl_model.Subgraphs(0).Operators(5).OutputsLength(), 1)

def test_fuse_transposeconv_relu(self):
class TestModel(nn.Module):
def __init__(self) -> None:
Expand All @@ -1394,7 +1394,7 @@ def __init__(self) -> None:

def forward(self, x):
y = self.act(self.transposeconv(x))
return y
return y

model = TestModel()
model.eval()
Expand All @@ -1403,7 +1403,7 @@ def forward(self, x):

converter = TFLiteConverter(model, dummy_input, model_path)
converter.convert()

tfl_model = parse_model(model_path)
self.assertEqual(tfl_model.OperatorCodesLength(), 1)
self.assertEqual(tfl_model.OperatorCodes(0).DeprecatedBuiltinCode(), tflite.BuiltinOperator.TRANSPOSE_CONV)
Expand Down
4 changes: 3 additions & 1 deletion tinynn/converter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
hybrid_gen_single_op_models: bool = False,
hybrid_config: typing.Optional[typing.Dict[str, bool]] = None,
group_tensors: bool = False,
missing_outputs_as_constants: bool = False,
) -> None:
""" The TFLiteConverter class
Expand Down Expand Up @@ -110,14 +111,15 @@ def __init__(
hybrid_gen_single_op_models: Generate both floating point and quantized version of the model for hybrid \
quantizable ops. Defaults to False
group_tensors (bool): Group tensors to save space. Defaults to False
missing_outputs_as_constants (bool): View missing outputs as constants. Defaults to False
"""

self.model = model
self.lower_model = None
self.graph = None
self.tensor_map = {}
self.tensor_map_copies = {}
self.common_graph = CommonGraph()
self.common_graph = CommonGraph(missing_outputs_as_constants)

if type(dummy_input) in (tuple, list):
self.dummy_input = dummy_input
Expand Down
34 changes: 22 additions & 12 deletions tinynn/converter/operators/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ class CommonGraph(object):
input_transpose: typing.List[bool]
output_transpose: typing.Union[typing.List[typing.Optional[bool]], typing.Optional[bool]]
node_op_counter: int
missing_outputs_as_constants: bool

def __init__(self) -> None:
def __init__(self, missing_outputs_as_constants: bool) -> None:
self.graph = ig.Graph(directed=True)
self.tensor_map = dict()
self.tensor_node_map = dict()
Expand All @@ -39,6 +40,8 @@ def __init__(self) -> None:
self.transform_store = {}
self.constant_mapping = {}

self.missing_outputs_as_constants = missing_outputs_as_constants

def add_transform_store(self, tensor_name: str, transform_name: str, new_tensor_name: str):
self.transform_store.setdefault(tensor_name, {})
self.transform_store[tensor_name][transform_name] = new_tensor_name
Expand Down Expand Up @@ -637,17 +640,24 @@ def collect_tensor_buffers(
missing_inputs = [name for name, _ in filter(lambda x: x[1] < 0, zip(inputs, input_idx))]
missing_outputs = [name for name, _ in filter(lambda x: x[1] < 0, zip(outputs, output_idx))]

assert len(missing_outputs) == 0, f'Some output nodes are missing: {missing_outputs}'

if len(missing_inputs) != 0:
warnings.warn(f'Some input nodes are missing: {missing_inputs}, will try to add them into graph')
for name in missing_inputs:
tensor = self.tensor_map[name]
tensor.index = tensor_idx
tensor_idx += 1
tensors.append(tensor)
item_idx = inputs.index(name)
input_idx[item_idx] = tensor.index
if not self.missing_outputs_as_constants:
assert len(missing_outputs) == 0, f'Some output nodes are missing: {missing_outputs}'

missing_vars_dict = {
'input': (missing_inputs, inputs, input_idx),
'output': (missing_outputs, outputs, output_idx),
}

for key, (missing_vars, var_indices, out_indices) in missing_vars_dict.items():
if len(missing_vars) != 0:
warnings.warn(f'Some {key} nodes are missing: {missing_vars}, will try to add them into graph')
for name in missing_vars:
tensor = self.tensor_map[name]
tensor.index = tensor_idx
tensor_idx += 1
tensors.append(tensor)
item_idx = var_indices.index(name)
out_indices[item_idx] = tensor.index

return tensors, buffers, input_idx, output_idx

Expand Down

0 comments on commit 6b62859

Please sign in to comment.