Move tensorflow/tsl/protobuf to xla/tsl/protobuf #6842

Status: Open. Wants to merge 1 commit into base: master.
@@ -37,7 +37,8 @@
'the TF version doesn\'t match with the TF Serving version. '
'We will try importing again with a workaround:%s', err)
from tensorflow.core.protobuf import error_codes_pb2 as old_error_codes_pb2
from tensorflow.tsl.protobuf import error_codes_pb2 as new_error_codes_pb2
from tensorflow.compiler.xla.tsl.protobuf import error_codes_pb2 as new_error_codes_pb2

old_error_codes_pb2._CODE = new_error_codes_pb2._CODE # pylint: disable=protected-access # pytype: disable=module-attr

# Retry.
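For context, the hunk above swaps the fallback import of error_codes_pb2 from tensorflow.tsl.protobuf to tensorflow.compiler.xla.tsl.protobuf, matching the protobuf move named in the PR title. Below is a minimal, hedged sketch of that fallback pattern; the wrapping try/except and the name of the retried module are assumptions for illustration, and only the two error_codes_pb2 imports, the warning text, and the _CODE aliasing appear in the diff itself.

```python
# Hedged sketch of the import fallback shown in the hunk above. The module
# being retried is hypothetical; the real file's name is not visible in this
# diff view.
import logging

try:
    import serving_protos  # hypothetical import that fails on TF / TF Serving version skew
except Exception as err:  # the real code may catch a narrower exception type
    logging.warning(
        'the TF version doesn\'t match with the TF Serving version. '
        'We will try importing again with a workaround:%s', err)
    from tensorflow.core.protobuf import error_codes_pb2 as old_error_codes_pb2
    from tensorflow.compiler.xla.tsl.protobuf import error_codes_pb2 as new_error_codes_pb2

    # Alias the Code enum descriptor from its new (xla/tsl) location onto the
    # old module path so code referencing either location resolves the same enum.
    old_error_codes_pb2._CODE = new_error_codes_pb2._CODE  # pylint: disable=protected-access

    # Retry the original import now that the alias is in place.
    import serving_protos  # hypothetical
```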
20 changes: 0 additions & 20 deletions tfx/orchestration/experimental/core/pipeline_ops.py
@@ -33,7 +33,6 @@
from tfx.dsl.io import filesystem
from tfx.orchestration import metadata
from tfx.orchestration import node_proto_view
from tfx.orchestration import subpipeline_utils
from tfx.orchestration.experimental.core import async_pipeline_task_gen
from tfx.orchestration.experimental.core import constants
from tfx.orchestration.experimental.core import env
@@ -1253,20 +1252,6 @@ def filter_by_pipeline_uid(
return lambda p: p.pipeline_uid == pipeline_uid


def _record_orchestration_time(pipeline_state: pstate.PipelineState) -> None:
"""Records an orchestration time for the pipeline run."""
# We only care about orchestration time for root pipelines, skip any
# subpipelines.
if subpipeline_utils.is_subpipeline(pipeline_state.pipeline):
return
pipeline_run_id = pipeline_state.pipeline_run_id
# Backend expects an empty string for the pipeline run id, for ASYNC pipeline
# runs.
if pipeline_run_id is None:
pipeline_run_id = ''
env.get_env().record_orchestration_time(pipeline_run_id)


@_pipeline_op()
def orchestrate(
mlmd_connection_manager: mlmd_cm.MLMDConnectionManager,
@@ -1337,7 +1322,6 @@ def orchestrate(
service_job_manager,
pipeline_state,
)
_record_orchestration_time(pipeline_state)
except Exception: # pylint: disable=broad-except
# If orchestrating a stop-initiated pipeline raises an exception, we log
# the exception but do not re-raise since we do not want to crash the
@@ -1361,7 +1345,6 @@ def orchestrate(
service_job_manager,
pipeline_state,
)
_record_orchestration_time(pipeline_state)
except Exception as e: # pylint: disable=broad-except
logging.exception(
'Exception raised while orchestrating update-initiated pipeline %s',
@@ -1381,7 +1364,6 @@
),
)
)
_record_orchestration_time(pipeline_state)
except Exception: # pylint: disable=broad-except
# If stop initiation also raised an exception , we log the exception but
# do not re-raise since we do not want to crash the orchestrator. If
@@ -1405,7 +1387,6 @@
service_job_manager,
pipeline_state,
)
_record_orchestration_time(pipeline_state)
except Exception as e: # pylint: disable=broad-except
logging.exception(
'Exception raised while orchestrating active pipeline %s',
@@ -1423,7 +1404,6 @@
message=f'Error orchestrating active pipeline: {str(e)}',
)
)
_record_orchestration_time(pipeline_state)
except Exception: # pylint: disable=broad-except
# If stop initiation also raised an exception , we log the exception but
# do not re-raise since we do not want to crash the orchestrator. If
103 changes: 2 additions & 101 deletions tfx/orchestration/experimental/core/pipeline_ops_test.py
@@ -1434,73 +1434,10 @@ def test_stop_node_wait_for_inactivation_timeout(self):
(pstate.NodeState.STOPPING, pstate.NodeState.STOPPED),
)

@parameterized.named_parameters(
dict(
testcase_name='async',
pipeline=_test_pipeline('pipeline1'),
expected_run_id='',
),
dict(
testcase_name='sync',
pipeline=_test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC),
expected_run_id='run0',
),
)
def test_record_orchestration_time(self, pipeline, expected_run_id):
with self._mlmd_cm as mlmd_connection_manager:
m = mlmd_connection_manager.primary_mlmd_handle
pipeline_ops.initiate_pipeline_start(m, pipeline)
environment = env.get_env()
with mock.patch.object(
environment,
'record_orchestration_time',
wraps=environment.record_orchestration_time,
) as mock_env_record_orchestration_time:
task_queue = tq.TaskQueue()
pipeline_ops.orchestrate(
mlmd_connection_manager,
task_queue,
self._mock_service_job_manager,
)
mock_env_record_orchestration_time.assert_called_with(expected_run_id)

def test_record_orchestration_time_subpipeline(self):
with self._mlmd_cm as mlmd_connection_manager:
m = mlmd_connection_manager.primary_mlmd_handle
pipeline = test_sync_pipeline.create_pipeline_with_subpipeline()
runtime_parameter_utils.substitute_runtime_parameter(
pipeline,
{
constants.PIPELINE_RUN_ID_PARAMETER_NAME: 'run0',
},
)
pipeline_ops.initiate_pipeline_start(m, pipeline)
environment = env.get_env()
with mock.patch.object(
environment,
'record_orchestration_time',
wraps=environment.record_orchestration_time,
) as mock_env_record_orchestration_time:
task_queue = tq.TaskQueue()
pipeline_ops.orchestrate(
mlmd_connection_manager,
task_queue,
self._mock_service_job_manager,
)
mock_env_record_orchestration_time.assert_called_with('run0')

@mock.patch.object(sync_pipeline_task_gen, 'SyncPipelineTaskGenerator')
@mock.patch.object(async_pipeline_task_gen, 'AsyncPipelineTaskGenerator')
@mock.patch.object(
pipeline_ops,
'_record_orchestration_time',
wraps=pipeline_ops._record_orchestration_time,
)
def test_orchestrate_active_pipelines(
self,
mock_record_orchestration_time,
mock_async_task_gen,
mock_sync_task_gen,
self, mock_async_task_gen, mock_sync_task_gen
):
with self._mlmd_cm as mlmd_connection_manager:
m = mlmd_connection_manager.primary_mlmd_handle
@@ -1572,15 +1509,6 @@ def test_orchestrate_active_pipelines(
service_jobs.DummyServiceJobManager(),
)

# Check that the orchestration time was recorded four times. Once for each
# of the four pipelines.
mock_record_orchestration_time.assert_has_calls([
mock.call(mock.ANY),
mock.call(mock.ANY),
mock.call(mock.ANY),
mock.call(mock.ANY),
])

self.assertEqual(2, mock_async_task_gen.return_value.generate.call_count)
self.assertEqual(2, mock_sync_task_gen.return_value.generate.call_count)

@@ -1622,15 +1550,9 @@ def test_orchestrate_active_pipelines(
@mock.patch.object(
task_gen_utils, 'generate_cancel_task_from_running_execution'
)
@mock.patch.object(
pipeline_ops,
'_record_orchestration_time',
wraps=pipeline_ops._record_orchestration_time,
)
def test_orchestrate_stop_initiated_pipelines(
self,
pipeline,
mock_record_orchestration_time,
mock_gen_task_from_active,
mock_async_task_gen,
mock_sync_task_gen,
@@ -1695,10 +1617,6 @@ def recorder(event):
self._mock_service_job_manager,
)
)
# We should have recorded the orchestration time once, for one pipeline.
# We reset after to verify this is true throughout.
mock_record_orchestration_time.assert_called_once()
mock_record_orchestration_time.reset_mock()

# PipelineFinished event should not trigger since not all the nodes are
# stopped.
@@ -1765,8 +1683,6 @@ def recorder(event):
self._mock_service_job_manager,
)
)
mock_record_orchestration_time.assert_called_once()
mock_record_orchestration_time.reset_mock()
self.assertTrue(task_queue.is_empty())
[execution] = m.store.get_executions_by_id([pipeline_execution_id])
self.assertEqual(
@@ -1805,7 +1721,6 @@ def recorder(event):
self._mock_service_job_manager,
)
)
mock_record_orchestration_time.assert_not_called()

@mock.patch.object(
task_gen_utils, 'generate_cancel_task_from_running_execution'
@@ -1975,14 +1890,7 @@ def recorder(event):
_test_pipeline('pipeline1'),
_test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC),
)
@mock.patch.object(
pipeline_ops,
'_record_orchestration_time',
wraps=pipeline_ops._record_orchestration_time,
)
def test_orchestrate_update_initiated_pipelines(
self, pipeline, mock_record_orchestration_time
):
def test_orchestrate_update_initiated_pipelines(self, pipeline):
with self._mlmd_cm as mlmd_connection_manager:
m = mlmd_connection_manager.primary_mlmd_handle
pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
@@ -2016,10 +1924,6 @@ def test_orchestrate_update_initiated_pipelines(
pipeline_ops.orchestrate(
mlmd_connection_manager, task_queue, self._mock_service_job_manager
)
# We should have recorded the orchestration time once, for one pipeline.
# We reset after to verify this is true throughout.
mock_record_orchestration_time.assert_called_once()
mock_record_orchestration_time.reset_mock()
# stop_node_services should be called for ExampleGen.
self._mock_service_job_manager.stop_node_services.assert_has_calls(
[mock.call(mock.ANY, 'ExampleGen')]
@@ -2050,9 +1954,6 @@ def test_orchestrate_update_initiated_pipelines(
self._mock_service_job_manager.stop_node_services.assert_has_calls(
[mock.call(mock.ANY, 'Transform')]
)
# Check that the orchestration time was recorded again.
mock_record_orchestration_time.assert_called_once()
mock_record_orchestration_time.reset_mock()

# Check that the node states are STARTING.
[execution] = m.store.get_executions_by_id([pipeline_state.execution_id])