Source code for aiida_cp2k.workchains.aiida_base_restart

# -*- coding: utf-8 -*-
# pylint: disable=inconsistent-return-statements,no-member
"""Base implementation of `WorkChain` class that implements a simple automated restart mechanism for calculations."""
from __future__ import absolute_import

from collections import namedtuple

from aiida import orm
from aiida.orm import Dict
from aiida.common import exceptions, AttributeDict, AiidaException
from aiida.common.lang import override
from aiida.engine import CalcJob, WorkChain, ToContext, append_, ExitCode
from aiida.plugins.entry_point import get_entry_point_names, load_entry_point


class UnexpectedCalculationFailure(AiidaException):
    """Raised when a calculation job has failed for an unexpected or unrecognized reason."""
class ErrorHandlerReport(namedtuple('ErrorHandlerReport', 'is_handled do_break exit_code')):
    """A namedtuple to define an error handler report for a :class:`~aiida.engine.processes.workchains.workchain.WorkChain`.

    This namedtuple should be returned by an error handling method of a workchain instance if the condition of the
    error handling was met by the failure mode of the calculation. If the error was appropriately handled, the
    'is_handled' field should be set to `True`, and `False` otherwise. If no further error handling should be
    performed after this method, the 'do_break' field should be set to `True`.

    :param is_handled: boolean, set to `True` when an error was handled, default is `False`
    :param do_break: boolean, set to `True` if no further error handling should be performed, default is `False`
    :param exit_code: an instance of the :class:`~aiida.engine.processes.exit_code.ExitCode` tuple
    """

    # Keep instances as lightweight as the plain namedtuple: no per-instance `__dict__`.
    __slots__ = ()

    def __new__(cls, is_handled=False, do_break=False, exit_code=ExitCode()):
        # Same defaults as the former `__new__.__defaults__ = (False, False, ExitCode())`. The `ExitCode()` default is
        # created once and shared by all reports. This assumes `ExitCode` instances are immutable; TODO confirm for the
        # aiida version in use.
        return super(ErrorHandlerReport, cls).__new__(cls, is_handled, do_break, exit_code)
def prepare_process_inputs(process, inputs):
    """Return the inputs of `process` in a form that is ready for submission.

    Any bare dictionary in `inputs` whose matching input port in the spec of `process` expects a `Dict` node is wrapped
    in a `Dict` node, so that the resulting inputs are valid for the process.

    :param process: sub class of `Process` for which to prepare the inputs dictionary
    :param inputs: a dictionary of inputs intended for submission of the process
    :return: a new `AttributeDict` in which bare dictionaries are wrapped in `Dict` wherever the process spec requires it
    """
    input_namespace = process.spec().inputs
    return AttributeDict(wrap_bare_dict_inputs(input_namespace, inputs))
def wrap_bare_dict_inputs(port_namespace, inputs):
    """Wrap bare dictionaries in `inputs` in a `Dict` node if dictated by the corresponding port in given namespace.

    Values without a matching port are passed through unchanged. Nested port namespaces are processed recursively when
    the corresponding value is itself a dictionary. Any other value is passed through untouched, so that the validation
    of the process spec reports it instead of this function crashing.

    :param port_namespace: a `PortNamespace`
    :param inputs: a dictionary of inputs intended for submission of the process
    :return: a new dictionary with all bare dictionaries wrapped in `Dict` if dictated by the port namespace
    """
    from aiida.engine.processes import PortNamespace

    wrapped = {}

    for key, value in inputs.items():

        if key not in port_namespace:
            wrapped[key] = value
            continue

        port = port_namespace[key]

        if isinstance(port, PortNamespace):
            # Only recurse into actual dictionaries: calling `.items()` on any other value would raise `AttributeError`.
            wrapped[key] = wrap_bare_dict_inputs(port, value) if isinstance(value, dict) else value
            continue

        # A port's `valid_type` may be a single class or a tuple of accepted classes.
        valid_type = port.valid_type
        accepts_dict_node = valid_type == Dict or (isinstance(valid_type, tuple) and Dict in valid_type)

        if accepts_dict_node and isinstance(value, dict):
            wrapped[key] = Dict(dict=value)
        else:
            wrapped[key] = value

    return wrapped
class BaseRestartWorkChain(WorkChain):
    """Base restart work chain.

    This work chain serves as the starting point for more complex work chains that will be designed to run a
    calculation that might need multiple restarts to come to a successful end. These restarts may be necessary because
    a single calculation run is not sufficient to achieve a fully converged result, or because certain recoverable
    errors may be encountered.

    This work chain implements the most basic functionality to achieve this goal. It will launch calculations,
    restarting until one completes successfully or the maximum number of iterations is reached. It can recover from
    errors through error handlers that can be attached dynamically through the `register_error_handler` decorator.

    The idea is to sub class this work chain and leverage the generic error handling that is implemented in the few
    outline methods. The minimally required outline would look something like the following::

        cls.setup
        while_(cls.should_run_calculation)(
            cls.run_calculation,
            cls.inspect_calculation,
        )

    Each of these methods can of course be overridden, but they should be general enough to fit most calculation
    cycles. The `run_calculation` method will take the inputs for the calculation process from the context under the
    key `inputs`. The user should therefore make sure that, before the `run_calculation` method is called, the inputs
    to be used are stored under `self.ctx.inputs`. One can update the inputs based on the results from a prior
    calculation by calling an outline method just before the `run_calculation` step, for example::

        cls.setup
        while_(cls.should_run_calculation)(
            cls.prepare_calculation,
            cls.run_calculation,
            cls.inspect_calculation,
        )

    In the `prepare_calculation` method, the inputs dictionary at `self.ctx.inputs` is updated before the next
    calculation is run with those inputs.

    The `_calculation_class` attribute should be set to the `CalcJob` class that should be run in the loop.
    """

    # When `True`, `results` reports every output node it attaches.
    _verbose = False
    # The `CalcJob` sub class launched on every iteration; must be set by sub classes.
    _calculation_class = None
    # Entry point group whose entry points are all loaded by `_load_error_handlers` (presumably modules that register
    # error handlers on import).
    _error_handler_entry_point = None

    def __init__(self, *args, **kwargs):
        """Construct the work chain and load the error handlers exposed through entry points.

        :raises ValueError: if `_calculation_class` is not a sub class of `CalcJob`
        """
        super(BaseRestartWorkChain, self).__init__(*args, **kwargs)

        calculation_class = self._calculation_class

        # `issubclass` raises `TypeError` for objects that are not classes; check for a class first so that an invalid
        # attribute always results in the documented `ValueError`.
        if not isinstance(calculation_class, type) or not issubclass(calculation_class, CalcJob):
            raise ValueError('no valid CalcJob class defined for `_calculation_class` attribute')

        self._load_error_handlers()

    @override
    def load_instance_state(self, saved_state, load_context):
        """Restore the instance state and reload the error handlers, which are not part of the saved state."""
        super(BaseRestartWorkChain, self).load_instance_state(saved_state, load_context)
        self._load_error_handlers()

    def _load_error_handlers(self):
        """If an error handler entry point is defined, load them. If the plugin cannot be loaded log it and pass."""
        if self._error_handler_entry_point is not None:
            for entry_point_name in get_entry_point_names(self._error_handler_entry_point):
                try:
                    load_entry_point(self._error_handler_entry_point, entry_point_name)
                    self.logger.info("loaded the '%s' entry point for the '%s' error handlers category",
                                     entry_point_name, self._error_handler_entry_point)
                except exceptions.EntryPointError as exception:
                    self.logger.warning("failed to load the '%s' entry point for the '%s' error handlers: %s",
                                        entry_point_name, self._error_handler_entry_point, exception)

    @classmethod
    def define(cls, spec):
        """Define the process specification: the `max_iterations` and `clean_workdir` inputs and the exit codes."""
        # yapf: disable
        # pylint: disable=bad-continuation
        super(BaseRestartWorkChain, cls).define(spec)
        spec.input('max_iterations', valid_type=orm.Int, default=lambda: orm.Int(5),
            help='Maximum number of iterations the work chain will restart the calculation to finish successfully.')
        spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False),
            help='If `True`, work directories of all called calculation will be cleaned at the end of execution.')
        spec.exit_code(101, 'ERROR_MAXIMUM_ITERATIONS_EXCEEDED',
            message='The maximum number of iterations was exceeded.')
        spec.exit_code(102, 'ERROR_SECOND_CONSECUTIVE_UNHANDLED_FAILURE',
            message='The calculation failed for an unknown reason, twice in a row.')

    def setup(self):
        """Initialize context variables that are used during the logical flow of the `BaseRestartWorkChain`."""
        self.ctx.calc_name = self._calculation_class.__name__
        self.ctx.unexpected_failure = False
        self.ctx.restart_calc = None
        self.ctx.is_finished = False
        self.ctx.iteration = 0

    def should_run_calculation(self):
        """Return whether a new calculation should be run.

        This is the case as long as the last calculation has not finished successfully and the maximum number of
        restarts has not yet been exceeded.
        """
        return not self.ctx.is_finished and self.ctx.iteration < self.inputs.max_iterations.value

    def run_calculation(self):
        """Run the next calculation, taking the input dictionary from the context at `self.ctx.inputs`.

        :raises AttributeError: if no inputs were stored in `self.ctx.inputs`
        """
        self.ctx.iteration += 1

        try:
            unwrapped_inputs = self.ctx.inputs
        except AttributeError:
            raise AttributeError('no calculation input dictionary was defined in `self.ctx.inputs`')

        # Set the `CALL` link label, creating the `metadata` namespace if the inputs do not define one yet.
        metadata = unwrapped_inputs.setdefault('metadata', {})
        metadata['call_link_label'] = 'iteration_{:02d}'.format(self.ctx.iteration)

        inputs = prepare_process_inputs(self._calculation_class, unwrapped_inputs)
        calculation = self.submit(self._calculation_class, **inputs)

        # Add a new empty list to the `errors_handled` extra. If any errors handled registered through the
        # `register_error_handler` decorator return an `ErrorHandlerReport`, their name will be appended to that list.
        errors_handled = self.node.get_extra('errors_handled', [])
        errors_handled.append([])
        self.node.set_extra('errors_handled', errors_handled)

        self.report('launching {}<{}> iteration #{}'.format(self.ctx.calc_name, calculation.pk, self.ctx.iteration))

        return ToContext(calculations=append_(calculation))

    def inspect_calculation(self):
        """Analyse the results of the previous calculation and call the error handlers when necessary.

        :return: an `ExitCode` if the work chain should be aborted, `None` otherwise
        """
        calculation = self.ctx.calculations[self.ctx.iteration - 1]

        # Done: successful completion of last calculation
        if calculation.is_finished_ok:

            # Perform an optional sanity check. If it returns an `ErrorHandlerReport` with a non-zero exit code, an
            # unrecoverable situation was detected and the work chain should be aborted. If it returns a report without
            # a non-zero exit code, the sanity check detected a problem but handled it, and the cycle is restarted.
            handler = self._handle_calculation_sanity_checks(calculation)  # pylint: disable=assignment-from-no-return

            if isinstance(handler, ErrorHandlerReport):
                exit_code = handler.exit_code

                # A report may carry `exit_code=None`; treat that like a zero exit code instead of crashing on `.status`.
                if exit_code is not None and exit_code.status != 0:
                    # Sanity check returned a handler with an exit code that is non-zero, so we abort
                    self.report('{}<{}> finished successfully, but sanity check detected unrecoverable problem'.format(
                        self.ctx.calc_name, calculation.pk))
                    return exit_code

                # Reset the `unexpected_failure` since we are restarting the calculation loop
                self.ctx.unexpected_failure = False
                self.report('{}<{}> finished successfully, but sanity check failed, restarting'.format(
                    self.ctx.calc_name, calculation.pk))
                return

            self.report('{}<{}> completed successfully'.format(self.ctx.calc_name, calculation.pk))
            self.ctx.restart_calc = calculation
            self.ctx.is_finished = True
            return

        # Unexpected: calculation was killed or an exception occurred, trigger unexpected failure handling
        if calculation.is_excepted or calculation.is_killed:
            return self._handle_unexpected_failure(calculation)

        # Failed: here the calculation is `Finished` but has a non-zero exit status, initiate the error handling
        try:
            exit_code = self._handle_calculation_failure(calculation)
        except UnexpectedCalculationFailure as exception:
            exit_code = self._handle_unexpected_failure(calculation, exception)

        return exit_code

    def results(self):
        """Attach the outputs specified in the output specification from the last completed calculation.

        :return: `ERROR_MAXIMUM_ITERATIONS_EXCEEDED` if the last calculation failed and the iteration limit was reached
        """
        calculation = self.ctx.calculations[self.ctx.iteration - 1]

        if calculation.is_failed and self.ctx.iteration >= self.inputs.max_iterations.value:
            # Abort: exceeded maximum number of retries
            self.report('reached the maximum number of iterations {}: last ran {}<{}>'.format(
                self.inputs.max_iterations.value, self.ctx.calc_name, calculation.pk))
            return self.exit_codes.ERROR_MAXIMUM_ITERATIONS_EXCEEDED

        self.report('work chain completed after {} iterations'.format(self.ctx.iteration))

        for name, port in self.spec().outputs.items():

            try:
                node = calculation.get_outgoing(link_label_filter=name).one().node
            except ValueError:
                # No single outgoing link with this label was found; only required ports are worth a report.
                if port.required:
                    self.report("required output '{}' was not an output of {}<{}>".format(
                        name, self.ctx.calc_name, calculation.pk))
            else:
                self.out(name, node)
                if self._verbose:
                    self.report("attaching the node {}<{}> as '{}'".format(node.__class__.__name__, node.pk, name))

    def on_terminated(self):
        """Clean the working directories of all child calculations if `clean_workdir=True` in the inputs."""
        super(BaseRestartWorkChain, self).on_terminated()

        if not self.inputs.clean_workdir.value:
            self.report('remote folders will not be cleaned')
            return

        cleaned_calcs = []

        for called_descendant in self.node.called_descendants:
            if isinstance(called_descendant, orm.CalcJobNode):
                try:
                    called_descendant.outputs.remote_folder._clean()  # pylint: disable=protected-access
                    cleaned_calcs.append(str(called_descendant.pk))
                except (IOError, OSError, KeyError, AttributeError, exceptions.NotExistent):
                    # Cleaning is best effort. A calculation may lack a `remote_folder` output, in which case the
                    # `outputs` lookup raises. Neither that nor an unreachable folder may make termination fail.
                    pass

        if cleaned_calcs:
            self.report('cleaned remote folders of calculations: {}'.format(' '.join(cleaned_calcs)))

    def _handle_calculation_sanity_checks(self, calculation):
        """Perform a sanity check of a calculation that finished ok.

        Calculations that were marked as successful by the parser may still have produced outputs that do not make
        sense but were not detected by the code, and so were not highlighted as warnings or errors. The consistency of
        the outputs can be checked here.

        - If an unrecoverable problem is found, return an `ErrorHandlerReport` with the appropriate non-zero exit code
          to abort the work chain.
        - If the problem can be fixed with a restart calculation, adapt the inputs as an error handler would and return
          an `ErrorHandlerReport` without a non-zero exit code. This signals to the work chain that a new calculation
          should be started.
        - If `None` is returned, the work chain assumes that the outputs produced by the calculation are good and
          nothing will be done.

        This base implementation performs no checks and returns `None`.

        :param calculation: the calculation whose outputs should be checked for consistency
        :return: `ErrorHandlerReport` if a new calculation should be launched, or to abort if it includes an exit code
        """

    def _handle_calculation_failure(self, calculation):
        """Call the attached error handlers if any to attempt to correct the cause of the calculation failure.

        The registered error handlers with a priority are called in order of decreasing priority until a handler returns
        a report that instructs to break. If the last executed error handler defines an exit code, that will be returned
        to instruct the work chain to abort. Otherwise the work chain will continue the cycle.

        :param calculation: the calculation that finished with a non-zero exit status
        :return: `ExitCode` if the work chain is to be aborted
        :raises `UnexpectedCalculationFailure`: if no error handlers were registered or no errors were handled.
        """
        is_handled = False
        handler_report = None

        if not hasattr(self, '_error_handlers') or not self._error_handlers:
            raise UnexpectedCalculationFailure('no calculation error handlers were registered')

        # Only handlers with a priority defined are called, highest priority first. Handlers whose priority is `None`
        # are skipped, while a priority of `0` is a valid priority and must not be filtered out.
        handlers = [handler for handler in self._error_handlers if handler.priority is not None]
        handlers = sorted(handlers, key=lambda x: x.priority, reverse=True)

        for handler in handlers:

            handler_report = handler.method(self, calculation)

            # If at least one error is handled, we consider the calculation failure handled.
            if handler_report and handler_report.is_handled:
                self.ctx.unexpected_failure = False
                is_handled = True

            # After certain error handlers, we may want to skip all other error handling
            if handler_report and handler_report.do_break:
                break

        # If none of the executed error handlers reported that they handled an error, the failure reason is unknown
        if not is_handled:
            raise UnexpectedCalculationFailure('calculation failure was not handled')

        # The last called error handler may not necessarily have returned a handler report
        if handler_report:
            return handler_report.exit_code

        return

    def _handle_unexpected_failure(self, calculation, exception=None):
        """Handle an unexpected failure.

        This occurs when a calculation excepted, was killed, or finished with a non-zero exit status but no errors were
        handled. If this is the second consecutive unexpected failure, the work chain is aborted.

        :param calculation: the calculation that failed in an unexpected way
        :param exception: optional exception or error message to log to the report
        :return: `ExitCode` if this is the second consecutive unexpected failure
        """
        if exception:
            self.report('{}'.format(exception))

        if self.ctx.unexpected_failure:
            self.report('failure of {}<{}> could not be handled for the second consecutive time'.format(
                self.ctx.calc_name, calculation.pk))
            return self.exit_codes.ERROR_SECOND_CONSECUTIVE_UNHANDLED_FAILURE

        self.ctx.unexpected_failure = True
        self.report('failure of {}<{}> could not be handled, restarting once more'.format(
            self.ctx.calc_name, calculation.pk))