# Source code for lembas.case

from __future__ import annotations

import inspect
import itertools
import logging
import types
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Iterator
from functools import WRAPPER_ASSIGNMENTS
from functools import cached_property
from pathlib import Path
from typing import Any
from typing import ClassVar
from typing import Generic
from typing import Optional
from typing import TypeVar

import toml

from lembas.logging import logger
from lembas.param import InputParameter
from lembas.results import Results

__all__ = ["Case", "CaseList", "step"]


# Path of the TOML case-summary file, joined onto a case's (relative) case
# directory by ``Case._write_lembas_file``.
LEMBAS_CASE_TOML_FILENAME = Path("lembas", "case.toml")

# Type variable bound to ``Case`` (as a forward-reference string, since ``Case`` is
# defined later in this module); used to make ``CaseList`` generic over case types.
TCase = TypeVar("TCase", bound="Case")
# Type of an undecorated step method: takes the case instance and returns None.
RawCaseStepMethod = Callable[[TCase], None]


class CaseStep:
    """A single step of a `Case`: a wrapped method plus its run condition and dependencies.

    Instances are normally created by the `step` decorator. They implement the
    descriptor protocol, so they can be called like ordinary methods on a case instance.

    Args:
        func: The undecorated step function. It receives the case instance as its
            only argument.
        condition: Optional predicate deciding whether the step runs. It can be a
            callable taking the case instance, a string of the form ``"attr"`` or
            ``"not attr"`` naming an attribute of the case, or ``None`` (always run).
        requires: Name of a step, or an iterable of step names, that must have
            completed before this step runs.

    Raises:
        ValueError: If ``condition`` is a malformed string.
        TypeError: If ``condition`` is neither a string, a callable, nor ``None``.

    """

    def __init__(
        self,
        func: RawCaseStepMethod,
        *,
        condition: Callable[[Any], bool] | str | None = None,
        requires: str | Iterable[str] | None = None,
    ):
        self._func = func
        self._condition = self._validate_condition(condition)
        # A bare string is a single step name; don't iterate it character by character.
        self.requires = (
            [requires] if isinstance(requires, str) else list(requires or [])
        )

    @cached_property
    def name(self) -> str:
        """The name of the case step (the wrapped function's ``__name__``)."""
        return self._func.__name__

    @staticmethod
    def _validate_condition(
        condition: Callable[[Any], bool] | str | None,
    ) -> Callable[[Any], bool]:
        """Normalize ``condition`` into a predicate that takes the case instance.

        Raises:
            ValueError: If a string condition is not ``"attr"`` or ``"not attr"``.
            TypeError: If ``condition`` is not a string, callable, or ``None``.

        """
        if condition is None:
            return lambda _: True

        if isinstance(condition, str):
            # str.split() with no argument already discards surrounding whitespace.
            parts = condition.split()
            if len(parts) == 1:
                attr_name = parts[0]
                return lambda case: getattr(case, attr_name)
            elif len(parts) == 2:
                # The only two-part form allowed is "not attribute_name"
                modifier, attr_name = parts
                if modifier.lower() != "not":
                    raise ValueError(
                        "Can only use 'not' as modifier for string-based condition"
                    )
                return lambda case: not getattr(case, attr_name)
            else:
                raise ValueError(
                    "A string-based condition can only be of the 'attribute_name' or "
                    "'not attribute_name' form"
                )

        # Fail at decoration time rather than when the case is eventually run.
        if not callable(condition):
            raise TypeError(
                "condition must be a callable, a string, or None; "
                f"got {type(condition).__name__}"
            )
        return condition

    def __call__(self, instance: Case) -> None:
        """Run the wrapped step on ``instance`` if its condition is satisfied."""
        if self._condition(instance):
            return self._func(instance)
        return None

    def __get__(self, instance: Any, cls: Any) -> CaseStep | types.MethodType:
        """We need to implement __get__ so that the `CaseStep` can be used as a method.

        Without doing this, Python will treat it as a normal attribute access, rather
        than as a descriptor. Access through an instance returns a bound method;
        access through the class returns the `CaseStep` itself, like a plain function.

        """
        if instance is None:
            # types.MethodType rejects a None instance, which would otherwise make
            # ``MyCase.some_step`` (and class introspection) raise TypeError.
            return self
        return types.MethodType(self, instance)


class Case:
    """Base case for all cases.

    When constructing a new case, all assigned ``InputParameter`` values can be set
    via keyword arguments.

    """

    # Steps collected per subclass by __init_subclass__. The empty default lets the
    # bare base class run (with no steps); it is copied, never mutated.
    _steps: ClassVar[dict[str, CaseStep]] = {}
    results: Results

    def __init_subclass__(cls, **kwargs: Any):
        # Cooperate with __init_subclass__ hooks of other base classes (mixins).
        super().__init_subclass__(**kwargs)
        # Only steps defined directly in this class body are collected.
        # NOTE(review): steps defined on a parent Case subclass are not inherited.
        cls._steps = {
            name: method
            for name, method in cls.__dict__.items()
            if isinstance(method, CaseStep)
        }

    def __init__(self, **kwargs: Any):
        self._completed_steps: set[str] = set()
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.results = Results(parent=self)

    def __str__(self) -> str:
        cls = self.__class__
        lines = [f"{cls.__name__}:"]
        for name, value in cls.__dict__.items():
            if isinstance(value, InputParameter):
                lines.append(f" - {name}: {getattr(self, name)}")
        return "\n".join(lines)

    @staticmethod
    def log(msg: str, *args: Any, level: int = logging.INFO) -> None:
        """Log a message to the logger."""
        logger.log(level, msg, *args)

    @property
    def fully_resolved_name(self) -> str:
        """The fully-resolved import path of the case handler."""
        cls = self.__class__
        mod = inspect.getmodule(cls)
        mod_prefix = mod.__name__ + "." if mod is not None else ""
        return mod_prefix + cls.__qualname__

    @cached_property
    def case_dir(self) -> Optional[Path]:
        """An optional property that can be set for a subclass.

        If set, the ``lembas/case.toml`` summary file will be written inside this
        directory.

        """
        return None

    @cached_property
    def relative_case_dir(self) -> Optional[Path]:
        """Return a case directory relative to current working directory if possible.

        If the case directory is not a subdirectory of current working directory, the
        absolute path is returned.

        """
        if self.case_dir is None:
            return None
        try:
            return self.case_dir.relative_to(Path.cwd())
        except ValueError:
            return self.case_dir

    @property
    def inputs(self) -> dict[str, Any]:
        """A mapping of the name of each InputParameter to its value.

        Only parameters whose ``include_in_inputs_dict`` is true are included.
        """
        attr_names = (
            k
            for k, v in self.__class__.__dict__.items()
            if isinstance(v, InputParameter) and v.include_in_inputs_dict
        )
        return {n: getattr(self, n) for n in attr_names}

    @property
    def _sorted_steps(self) -> Iterator[CaseStep]:
        """Yield the case steps in order, with proper sorting of dependencies.

        On each pass the first pending step (in definition order) whose ``requires``
        have all completed is yielded. It is recorded as completed when the consumer
        resumes the generator, i.e. after the consumer has run it.

        Raises:
            ValueError: If none of the remaining steps can ever run (for example a
                misspelled required step name or a circular dependency), instead of
                looping forever.

        """
        pending = dict(self._steps)
        while pending:
            ready = next(
                (
                    name
                    for name, case_step in pending.items()
                    if set(case_step.requires).issubset(self._completed_steps)
                ),
                None,
            )
            if ready is None:
                unmet = {
                    name: sorted(set(case_step.requires) - self._completed_steps)
                    for name, case_step in pending.items()
                }
                raise ValueError(
                    f"Cannot resolve step dependencies for {type(self).__name__}; "
                    f"unsatisfiable requirements: {unmet}"
                )
            yield pending.pop(ready)
            self._completed_steps.add(ready)

    def _write_lembas_file(self) -> None:
        """Write a file in the case directory specifying the case handler and all input values used.

        Does nothing when ``relative_case_dir`` is ``None``.
        """
        if self.relative_case_dir is None:
            return None
        case_summary_file = self.relative_case_dir / LEMBAS_CASE_TOML_FILENAME
        case_summary_file.parent.mkdir(parents=True, exist_ok=True)
        self.log("Writing case summary to: %s", case_summary_file)
        contents = {
            "lembas": {"inputs": self.inputs, "case-handler": self.fully_resolved_name}
        }
        # TOML documents are UTF-8; don't depend on the platform's default encoding.
        with case_summary_file.open("w", encoding="utf-8") as fp:
            toml.dump(contents, fp)

    def run(self) -> None:
        """Run the case.

        If this method is not overridden, the default behavior is to run all the
        methods decorated with ``@step``.

        """
        self.log("Running %s", self)
        self._write_lembas_file()
        for step_method in self._sorted_steps:
            step_method(self)
class CaseList(Generic[TCase]):
    """A generic collection of ``Case`` objects, and utility methods to create and run them.

    Args:
        cases: An optional iterable of ``Case`` objects used to initialize the
            ``CaseList``.

    """

    def __init__(self, cases: Iterable[TCase] | None = None):
        # Compare against None instead of testing truthiness, so iterables with
        # ambiguous or expensive truth values are still accepted. The cases are copied
        # into a private list.
        self._cases: list[TCase] = list(() if cases is None else cases)

    def add(self, case: TCase) -> TCase:
        """Add a case to the list.

        Args:
            case: The case to add.

        Returns:
            The case that was added.

        """
        self._cases.append(case)
        return case

    def add_cases_by_parameter_sweep(
        self, case_class: type[TCase], **kwargs: Any
    ) -> None:
        """Add a number of cases by performing a parameter sweep using the Cartesian product.

        Args:
            case_class: The type of case to construct.
            kwargs: Any parameters to pass to the case constructors. If iterable values
                are provided, they will be used when performing the parameter sweep via
                ``itertools.product``. Strings and byte strings are treated as single
                values, not swept over.

        """
        # Wrap scalars, and str/bytes (iterable, but meant as single values), in a
        # one-element list so every parameter can take part in the product.
        sweep: dict[str, Iterable[Any]] = {}
        for key, value in kwargs.items():
            if isinstance(value, (str, bytes, bytearray)) or not isinstance(
                value, Iterable
            ):
                value = [value]
            sweep[key] = value

        for combination in itertools.product(*sweep.values()):
            self.add(case_class(**dict(zip(sweep, combination))))

    def run_all(self) -> None:
        """Run all the cases."""
        for case in self._cases:
            case.run()

    def __contains__(self, item: TCase) -> bool:
        return item in self._cases

    def __len__(self) -> int:
        return len(self._cases)

    def __iter__(self) -> Iterator[TCase]:
        return iter(self._cases)
def step(
    method: RawCaseStepMethod | None = None,
    /,
    condition: Callable[[Any], bool] | str | None = None,
    requires: str | Iterable[str] | None = None,
) -> Any:
    """Decorate a method of a `Case` subclass so that it runs as a step in `Case.run`.

    May be applied bare (``@step``) or with arguments (``@step(condition=...)``).
    The decorated method should not return a value.

    Args:
        method: The function being decorated when the decorator is used without
            parentheses; otherwise ``None``, in which case a decorator is returned.
        condition: Optional predicate controlling whether the step runs. A callable
            receives the `Case` instance and must return a bool: if True the step
            runs, otherwise it is skipped. A string is treated as an attribute lookup
            on the case: ``condition="plot"`` behaves like ``lambda case: case.plot``,
            and ``condition="not plot"`` behaves like ``lambda case: not case.plot``.
        requires: Name of a step, or an iterable of step names, that must run
            before this one.

    Returns:
        A `CaseStep` when ``method`` is given, otherwise a decorator producing one.

    Usage:

    .. code-block::

        class MyCase(Case):
            @step(condition=lambda case: case.case_dir.exists())
            def some_analysis_step(self):
                # do something

    """

    def _make_step(func: RawCaseStepMethod) -> CaseStep:
        case_step = CaseStep(func, condition=condition, requires=requires)
        # A hand-rolled replica of functools.wraps, which reportedly doesn't work
        # here: copy __name__, __doc__, etc. onto the CaseStep, using None for any
        # attribute the function lacks.
        for attr_name in WRAPPER_ASSIGNMENTS:
            attr_value = getattr(func, attr_name, None)
            setattr(case_step, attr_name, attr_value)
        return case_step

    # With arguments (``@step(...)``), hand back the decorator itself.
    if method is None:
        return _make_step
    # Bare ``@step``: the function was passed in directly.
    return _make_step(method)