Commit: working implementation
khsrali committed Oct 2, 2024
1 parent 3cb8539 commit 75bb430
Showing 7 changed files with 58 additions and 43 deletions.
3 changes: 2 additions & 1 deletion environment.yml
@@ -22,7 +22,6 @@ dependencies:
- importlib-metadata~=6.0
- numpy~=1.21
- paramiko~=3.0
- plumpy~=0.22.3
- pgsu~=0.3.0
- psutil~=5.6
- psycopg[binary]~=3.0
@@ -35,3 +34,5 @@ dependencies:
- tqdm~=4.45
- upf_to_json~=0.9.2
- wrapt~=1.11
- pip:
- plumpy@git+https://github.com/aiidateam/plumpy.git@force-kill#egg=plumpy
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -34,7 +34,7 @@ dependencies = [
'importlib-metadata~=6.0',
'numpy~=1.21',
'paramiko~=3.0',
'plumpy~=0.22.3',
'plumpy@git+https://github.com/aiidateam/plumpy.git@force-kill#egg=plumpy',
'pgsu~=0.3.0',
'psutil~=5.6',
'psycopg[binary]~=3.0',
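Both environment.yml and pyproject.toml now pull plumpy from its force-kill development branch instead of a released version. As a quick sanity check of an installed environment, one could verify that this plumpy actually ships the new interruption class used by the engine changes below (a sketch; only the attribute check is assumed):

import plumpy

# The force-kill branch is expected to provide ForceKillInterruption, which the
# engine changes in this commit rely on; a released plumpy (~0.22) does not ship it.
if not hasattr(plumpy.process_states, 'ForceKillInterruption'):
    raise RuntimeError(
        'plumpy was not installed from the force-kill branch; '
        'use the pip spec pinned in environment.yml / pyproject.toml'
    )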
16 changes: 11 additions & 5 deletions src/aiida/cmdline/commands/cmd_process.py
@@ -320,7 +320,10 @@ def process_status(call_link_label, most_recent_node, max_depth, processes):
@options.ALL(help='Kill all processes if no specific processes are specified.')
@options.TIMEOUT()
@options.WAIT()
@options.FORCE_KILL()
@options.FORCE_KILL(
help='Force kill the process if it does not respond to the initial kill signal.\n'
' Note: This may lead to orphaned jobs on your HPC and should be used with caution.'
)
@decorators.with_dbenv()
def process_kill(processes, all_entries, timeout, wait, force_kill):
"""Kill running processes.
@@ -341,10 +344,13 @@ def process_kill(processes, all_entries, timeout, wait, force_kill):

with capture_logging() as stream:
try:
message = 'Killed through `verdi process kill`'
control.kill_processes(
processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message, force_kill=force_kill
)
if force_kill:
echo.echo_warning('Force kill is enabled. This may lead to orphaned jobs on your HPC.')
# Note: it is important to include `-F` in the message, as it is used to identify force-killed processes.
message = 'Force killed through `verdi process kill -F`'
else:
message = 'Killed through `verdi process kill`'
control.kill_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message)
except control.ProcessTimeoutException as exception:
echo.echo_critical(f'{exception}\n{REPAIR_INSTRUCTIONS}')

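With this change the command encodes the force-kill intent purely in the kill message instead of forwarding a force_kill flag to control.kill_processes. A minimal sketch of the same branching outside Click (the nodes list and the kill helper are hypothetical; the call and message strings are taken from this diff):

from aiida.engine.processes import control

def kill(nodes, force_kill=False, timeout=5.0, wait=True):
    """Hypothetical helper mirroring the branching in `verdi process kill`."""
    if force_kill:
        # The `-F` marker in the message is what identifies force-killed processes downstream.
        message = 'Force killed through `verdi process kill -F`'
    else:
        message = 'Killed through `verdi process kill`'
    # Note that `force_kill` itself is no longer passed on to `kill_processes`.
    control.kill_processes(nodes, all_entries=False, timeout=timeout, wait=wait, message=message)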
4 changes: 2 additions & 2 deletions src/aiida/cmdline/params/options/main.py
@@ -330,11 +330,11 @@ def set_log_level(ctx, _param, value):
FORCE = OverridableOption('-f', '--force', is_flag=True, default=False, help='Do not ask for confirmation.')

FORCE_KILL = OverridableOption(
'-fk',
'-F',
'--force-kill',
is_flag=True,
default=False,
help='Kills the process without waiting for a response if the job is killed.',
help='Kills the process without waiting for confirmation that the job has been killed on the remote.',
)

SILENT = OverridableOption('-s', '--silent', is_flag=True, default=False, help='Suppress any output printed to stdout.')
49 changes: 29 additions & 20 deletions src/aiida/engine/processes/calcjobs/tasks.py
@@ -101,9 +101,9 @@ async def do_upload():

try:
logger.info(f'scheduled request to upload CalcJob<{node.pk}>')
ignore_exceptions = (plumpy.futures.CancelledError, PreSubmitException, plumpy.process_states.Interruption)
breaking_exceptions = (plumpy.futures.CancelledError, PreSubmitException, plumpy.process_states.Interruption)
skip_submit = await exponential_backoff_retry(
do_upload, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
do_upload, initial_interval, max_attempts, logger=node.logger, breaking_exceptions=breaking_exceptions
)
except PreSubmitException:
raise
@@ -149,9 +149,9 @@ async def do_submit():

try:
logger.info(f'scheduled request to submit CalcJob<{node.pk}>')
ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
breaking_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
result = await exponential_backoff_retry(
do_submit, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
do_submit, initial_interval, max_attempts, logger=node.logger, breaking_exceptions=breaking_exceptions
)
except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
raise
@@ -207,9 +207,9 @@ async def do_update():

try:
logger.info(f'scheduled request to update CalcJob<{node.pk}>')
ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
breaking_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
job_done = await exponential_backoff_retry(
do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
do_update, initial_interval, max_attempts, logger=node.logger, breaking_exceptions=breaking_exceptions
)
except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
raise
@@ -258,9 +258,9 @@ async def do_monitor():

try:
logger.info(f'scheduled request to monitor CalcJob<{node.pk}>')
ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
breaking_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
monitor_result = await exponential_backoff_retry(
do_monitor, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
do_monitor, initial_interval, max_attempts, logger=node.logger, breaking_exceptions=breaking_exceptions
)
except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
raise
@@ -334,9 +334,9 @@ async def do_retrieve():

try:
logger.info(f'scheduled request to retrieve CalcJob<{node.pk}>')
ignore_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
breaking_exceptions = (plumpy.futures.CancelledError, plumpy.process_states.Interruption)
result = await exponential_backoff_retry(
do_retrieve, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions
do_retrieve, initial_interval, max_attempts, logger=node.logger, breaking_exceptions=breaking_exceptions
)
except (plumpy.futures.CancelledError, plumpy.process_states.Interruption):
raise
@@ -385,7 +385,7 @@ async def do_stash():
initial_interval,
max_attempts,
logger=node.logger,
ignore_exceptions=plumpy.process_states.Interruption,
breaking_exceptions=plumpy.process_states.Interruption,
)
except plumpy.process_states.Interruption:
raise
@@ -398,7 +398,9 @@ async def do_stash():
return


async def task_kill_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture):
async def task_kill_job(
node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture, force_kill: bool = False
):
"""Transport task that will attempt to kill a job calculation.
The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager
@@ -412,7 +414,6 @@ async def task_kill_job(node: CalcJobNode, transport_queue: TransportQueue, canc
:raises: TransportTaskException if after the maximum number of retries the transport task still excepted
"""
breakpoint()
initial_interval = get_config_option(RETRY_INTERVAL_OPTION)
max_attempts = get_config_option(MAX_ATTEMPTS_OPTION)

@@ -423,19 +424,23 @@
authinfo = node.get_authinfo()

async def do_kill():
# This coroutine fails when no transport can be obtained, in which case
# the exponential backoff retry below eventually raises.
with transport_queue.request_transport(authinfo) as request:
transport = await cancellable.with_interrupt(request)
return execmanager.kill_calculation(node, transport)

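# A force kill deliberately skips the scheduler round-trip: ForceKillInterruption is raised
# right away, so `do_kill` and the exponential backoff retry below never run.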
if force_kill:
logger.warning(f'Process<{node.pk}> has been force killed! This may result in orphaned jobs.')
raise plumpy.process_states.ForceKillInterruption('Force killing CalcJob')
try:
logger.info(f'scheduled request to kill CalcJob<{node.pk}>')
result = await exponential_backoff_retry(do_kill, initial_interval, max_attempts, logger=node.logger)
# Note: any exception raised here will cause the process to except rather than be killed,
# and can therefore leave orphaned jobs behind!
except plumpy.process_states.Interruption:
logger.warning(f'killing CalcJob<{node.pk}> excepted, the job might be orphaned.')
raise
except Exception as exception:
logger.warning(f'killing CalcJob<{node.pk}> failed')
logger.warning(f'killing CalcJob<{node.pk}> excepted, the job might be orphaned.')
raise TransportTaskException(f'kill_calculation failed {max_attempts} times consecutively') from exception
else:
logger.info(f'killing CalcJob<{node.pk}> successful')
@@ -531,7 +536,7 @@ async def execute(self) -> plumpy.process_states.State: # type: ignore[override
monitor_result = await self._monitor_job(node, transport_queue, self.monitors)

if monitor_result and monitor_result.action is CalcJobMonitorAction.KILL:
await self._kill_job(node, transport_queue)
await self._kill_job(node, transport_queue, force_kill=False)
job_done = True

if monitor_result and not monitor_result.retrieve:
@@ -570,7 +575,11 @@ async def execute(self) -> plumpy.process_states.State: # type: ignore[override
except TransportTaskException as exception:
raise plumpy.process_states.PauseInterruption(f'Pausing after failed transport task: {exception}')
except plumpy.process_states.KillInterruption as exception:
await self._kill_job(node, transport_queue)
await self._kill_job(node, transport_queue, force_kill=False)
node.set_process_status(str(exception))
return self.retrieve(monitor_result=self._monitor_result)
except plumpy.process_states.ForceKillInterruption as exception:
await self._kill_job(node, transport_queue, force_kill=True)
node.set_process_status(str(exception))
return self.retrieve(monitor_result=self._monitor_result)
except (plumpy.futures.CancelledError, asyncio.CancelledError):
@@ -610,9 +619,9 @@ async def _monitor_job(self, node, transport_queue, monitors) -> CalcJobMonitorR

return monitor_result

async def _kill_job(self, node, transport_queue) -> None:
async def _kill_job(self, node, transport_queue, force_kill) -> None:
"""Kill the job."""
await self._launch_task(task_kill_job, node, transport_queue)
await self._launch_task(task_kill_job, node, transport_queue, force_kill=force_kill)
if self._killing is not None:
self._killing.set_result(True)
else:
15 changes: 7 additions & 8 deletions src/aiida/engine/processes/control.py
@@ -144,7 +144,7 @@ def pause_processes(
.. note:: Requires the daemon to be running, or processes will be unresponsive.
:param processes: List of processes to play.
:param processes: List of processes to pause.
:param all_entries: Pause all playing processes.
:param timeout: Raise a ``ProcessTimeoutException`` if the process does not respond within this amount of seconds.
:param wait: Set to ``True`` to wait for process response, for ``False`` the action is fire-and-forget.
@@ -174,7 +174,6 @@ def kill_processes(
all_entries: bool = False,
timeout: float = 5.0,
wait: bool = False,
force_kill: bool = False,
) -> None:
"""Kill running processes.
@@ -184,7 +183,6 @@
:param all_entries: Kill all active processes.
:param timeout: Raise a ``ProcessTimeoutException`` if the process does not respond within this amount of seconds.
:param wait: Set to ``True`` to wait for process response, for ``False`` the action is fire-and-forget.
:param force_kill: ... TODO
:raises ``ProcessTimeoutException``: If the processes do not respond within the timeout.
"""
if not get_daemon_client().is_daemon_running:
@@ -201,9 +199,7 @@
return

controller = get_manager().get_process_controller()
_perform_actions(
processes, controller.kill_process, 'kill', 'killing', timeout, wait, msg=message, force_kill=force_kill
)
_perform_actions(processes, controller.kill_process, 'kill', 'killing', timeout, wait, msg=message)


def _perform_actions(
@@ -283,9 +279,12 @@ def handle_result(result):
try:
# unwrap is need here since LoopCommunicator will also wrap a future
unwrapped = unwrap_kiwi_future(future)
result = unwrapped.result()
result = unwrapped.result(timeout=timeout)
except communications.TimeoutError:
LOGGER.error(f'call to {infinitive} Process<{process.pk}> timed out')
if process.is_terminated:
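# The reply timed out, but the process has already reached a terminal state,
# so there is nothing left to act on; report instead of logging an error.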
LOGGER.report(f'request to {infinitive} Process<{process.pk}> sent')
else:
LOGGER.error(f'call to {infinitive} Process<{process.pk}> timed out')
except Exception as exception:
LOGGER.error(f'failed to {infinitive} Process<{process.pk}>: {exception}')
else:
12 changes: 6 additions & 6 deletions src/aiida/engine/utils.py
@@ -101,7 +101,7 @@ async def with_interrupt(self, coro: Awaitable[Any]) -> Any:
import asyncio
loop = asyncio.get_event_loop()
interruptable = InterutableFuture()
interruptable = InterruptableFuture()
loop.call_soon(interruptable.interrupt, RuntimeError("STOP"))
loop.run_until_complete(interruptable.with_interrupt(asyncio.sleep(2.)))
>>> RuntimeError: STOP
@@ -124,7 +124,7 @@ def interruptable_task(
) -> InterruptableFuture:
"""Turn the given coroutine into an interruptable task by turning it into an InterruptableFuture and returning it.
:param coro: the coroutine that should be made interruptable with object of InterutableFuture as last paramenter
:param coro: the coroutine that should be made interruptable with object of InterruptableFuture as last parameter
:param loop: the event loop in which to run the coroutine, by default uses asyncio.get_event_loop()
:return: an InterruptableFuture
"""
@@ -178,7 +178,7 @@ async def exponential_backoff_retry(
initial_interval: Union[int, float] = 10.0,
max_attempts: int = 5,
logger: Optional[logging.Logger] = None,
ignore_exceptions: Union[None, Type[Exception], Tuple[Type[Exception], ...]] = None,
breaking_exceptions: Union[None, Type[Exception], Tuple[Type[Exception], ...]] = None,
) -> Any:
"""Coroutine to call a function, recalling it with an exponential backoff in the case of an exception
@@ -190,7 +190,8 @@ async def exponential_backoff_retry(
:param fct: the function to call, which will be turned into a coroutine first if it is not already
:param initial_interval: the time to wait after the first caught exception before calling the coroutine again
:param max_attempts: the maximum number of times to call the coroutine before re-raising the exception
:param ignore_exceptions: exceptions to ignore, i.e. when caught do nothing and simply re-raise
:param breaking_exceptions: exceptions that break out of the EBM loop and are re-raised immediately.
    If None, exceptions are only raised once ``max_attempts`` has been reached.
:return: result if the ``coro`` call completes within ``max_attempts`` retries without raising
"""
if logger is None:
@@ -205,8 +206,7 @@
result = await coro()
break # Finished successfully
except Exception as exception:
# Re-raise exceptions that should be ignored
if ignore_exceptions is not None and isinstance(exception, ignore_exceptions):
if breaking_exceptions is not None and isinstance(exception, breaking_exceptions):
raise

count = iteration + 1
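To make the renamed parameter concrete, here is a minimal usage sketch: flaky_operation is a hypothetical coroutine, while the exception tuple mirrors the breaking_exceptions passed by the transport tasks above.

import asyncio

import plumpy

from aiida.engine.utils import exponential_backoff_retry


async def flaky_operation():
    """Hypothetical transport-bound coroutine that may raise transient exceptions."""
    return 'done'


async def main():
    return await exponential_backoff_retry(
        flaky_operation,
        initial_interval=10.0,
        max_attempts=5,
        # Interruptions and cancellations break out of the retry loop and are re-raised
        # immediately; any other exception is retried up to `max_attempts` times.
        breaking_exceptions=(plumpy.futures.CancelledError, plumpy.process_states.Interruption),
    )


if __name__ == '__main__':
    print(asyncio.run(main()))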
