Skip to content

Commit

Permalink
Fix or ignore some pytype errors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 662500667
  • Loading branch information
frigus02 authored and tfx-copybara committed Aug 13, 2024
1 parent 90c8da3 commit 6bd59c6
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 14 deletions.
8 changes: 4 additions & 4 deletions tfx/dsl/input_resolution/ops/latest_policy_model_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ def add_downstream_artifact(
"""Adds a downstream artifact to the ModelRelations."""
artifact_type_name = downstream_artifact.type
if _is_eval_blessed(artifact_type_name, downstream_artifact):
self.model_blessing_artifacts.append(downstream_artifact)
self.model_blessing_artifacts.append(downstream_artifact) # pytype: disable=container-type-mismatch # dont-delete-module-type

elif _is_infra_blessed(artifact_type_name, downstream_artifact):
self.infra_blessing_artifacts.append(downstream_artifact)
self.infra_blessing_artifacts.append(downstream_artifact) # pytype: disable=container-type-mismatch # dont-delete-module-type

elif artifact_type_name == ops_utils.MODEL_PUSH_TYPE_NAME:
self.model_push_artifacts.append(downstream_artifact)
self.model_push_artifacts.append(downstream_artifact) # pytype: disable=container-type-mismatch # dont-delete-module-type

def meets_policy(self, policy: Policy) -> bool:
"""Checks if ModelRelations contains artifacts that meet the Policy."""
Expand Down Expand Up @@ -486,7 +486,7 @@ def event_filter(event):
]
# Set `max_num_hops` to 50, which should be enough for this use case.
batch_downstream_artifacts_and_types_by_model_identifier = (
mlmd_resolver.get_downstream_artifacts_by_artifacts(
mlmd_resolver.get_downstream_artifacts_by_artifacts( # pytype: disable=wrong-arg-types # dont-delete-module-type
batch_model_artifacts,
max_num_hops=ops_utils.LATEST_POLICY_MODEL_OP_MAX_NUM_HOPS,
filter_query=filter_query,
Expand Down
12 changes: 6 additions & 6 deletions tfx/dsl/input_resolution/ops/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def create_examples(
)
self.put_execution(
'ExampleGen',
inputs={},
inputs={}, # pytype: disable=wrong-arg-types # dont-delete-module-type
outputs={'examples': self.unwrap_tfx_artifacts(examples)},
contexts=contexts,
connection_config=connection_config,
Expand All @@ -275,7 +275,7 @@ def transform_examples(
)
self.put_execution(
'Transform',
inputs=inputs,
inputs=inputs, # pytype: disable=wrong-arg-types # dont-delete-module-type
outputs={
'transform_graph': self.unwrap_tfx_artifacts([transform_graph])
},
Expand All @@ -298,7 +298,7 @@ def train_on_examples(
inputs['transform_graph'] = self.unwrap_tfx_artifacts([transform_graph])
self.put_execution(
'TFTrainer',
inputs=inputs,
inputs=inputs, # pytype: disable=wrong-arg-types # dont-delete-module-type
outputs={'model': self.unwrap_tfx_artifacts([model])},
contexts=contexts,
connection_config=connection_config,
Expand All @@ -325,7 +325,7 @@ def evaluator_bless_model(

self.put_execution(
'Evaluator',
inputs=inputs,
inputs=inputs, # pytype: disable=wrong-arg-types # dont-delete-module-type
outputs={'blessing': self.unwrap_tfx_artifacts([model_blessing])},
contexts=contexts,
connection_config=connection_config,
Expand Down Expand Up @@ -353,7 +353,7 @@ def infra_validator_bless_model(

self.put_execution(
'InfraValidator',
inputs={'model': self.unwrap_tfx_artifacts([model])},
inputs={'model': self.unwrap_tfx_artifacts([model])}, # pytype: disable=wrong-arg-types # dont-delete-module-type
outputs={'result': self.unwrap_tfx_artifacts([model_infra_blessing])},
contexts=contexts,
connection_config=connection_config,
Expand All @@ -375,7 +375,7 @@ def push_model(
)
self.put_execution(
'ServomaticPusher',
inputs={'model_export': self.unwrap_tfx_artifacts([model])},
inputs={'model_export': self.unwrap_tfx_artifacts([model])}, # pytype: disable=wrong-arg-types # dont-delete-module-type
outputs={'model_push': self.unwrap_tfx_artifacts([model_push])},
contexts=contexts,
connection_config=connection_config,
Expand Down
2 changes: 1 addition & 1 deletion tfx/orchestration/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def get_published_artifacts_by_type_within_context(
@staticmethod
def _get_legacy_producer_component_id(
execution: metadata_store_pb2.Execution) -> str:
return execution.properties[_EXECUTION_TYPE_KEY_COMPONENT_ID].string_value
return execution.properties[_EXECUTION_TYPE_KEY_COMPONENT_ID].string_value # pytype: disable=bad-return-type # dont-delete-module-type

def get_qualified_artifacts(
self,
Expand Down
2 changes: 1 addition & 1 deletion tfx/orchestration/portable/importer_node_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _extract_proto_map(
extract_mlmd_value = lambda v: getattr(v, v.WhichOneof('value'))
return {k: extract_mlmd_value(v.field_value) for k, v in proto_map.items()}

def run(
def run( # pytype: disable=signature-mismatch # dont-delete-module-type
self, mlmd_connection: metadata.Metadata,
pipeline_node: pipeline_pb2.PipelineNode,
pipeline_info: pipeline_pb2.PipelineInfo,
Expand Down
2 changes: 1 addition & 1 deletion tfx/orchestration/portable/partial_run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ def _get_base_pipeline_run_context(
pipeline_run_contexts, key=lambda c: c.create_time_since_epoch
)
if not sorted_run_contexts:
return None
return None # pytype: disable=bad-return-type # dont-delete-module-type

logging.info(
'base_run_id not provided. Default to latest pipeline run: %s',
Expand Down
2 changes: 1 addition & 1 deletion tfx/orchestration/portable/resolver_node_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _extract_proto_map(
extract_mlmd_value = lambda v: getattr(v, v.WhichOneof('value'))
return {k: extract_mlmd_value(v.field_value) for k, v in proto_map.items()}

def run(
def run( # pytype: disable=signature-mismatch # dont-delete-module-type
self, mlmd_connection: metadata.Metadata,
pipeline_node: pipeline_pb2.PipelineNode,
pipeline_info: pipeline_pb2.PipelineInfo,
Expand Down

0 comments on commit 6bd59c6

Please sign in to comment.