Source code for caskade.context

from .module import Module
from .param import Param
from .errors import ActiveStateError


class ActiveContext:
    """
    Context manager to activate a module for a simulation. Only inside an
    ActiveContext is it possible to fill/clear the dynamic and live parameters.

    Behaviour on entry:

    * If ``module.online`` is already true, raise ``ActiveStateError``.
    * If ``module.active`` is already true (presumably because an enclosing
      activation is in effect -- confirm against ``Module.active``), snapshot
      the current ``_value`` of every parameter in ``module.all_params`` so it
      can be put back on exit.
    * Otherwise add the ``"active"`` memo to the module.
    * In all non-raising cases add the ``"<module name>_active"`` memo.

    On exit, ``module.clear_state()`` is called and the
    ``"<module name>_active"`` memo is removed. Then either the snapshotted
    values are restored, or the ``"active"`` memo added on entry is removed.

    Raises
    ------
    ActiveStateError
        On entry, if the module is already running a simulation
        (``module.online`` is true).
    """

    def __init__(self, module: Module):
        self.module = module

    def __enter__(self):
        if self.module.online:
            raise ActiveStateError(f"Module '{self.module.name}' is already running a simulation")
        if self.module.active:
            # Keep each value paired with its parameter. Restoration then does
            # not depend on ``all_params`` producing the same params in the
            # same order when it is read again in ``__exit__``.
            self.state = [(p, p._value) for p in self.module.all_params]
        else:
            self.state = None
            self.module.add_memo("active")
        self.module.add_memo(f"{self.module.name}_active")

    def __exit__(self, exc_type, exc_value, traceback):
        self.module.clear_state()
        self.module.remove_memo(f"{self.module.name}_active")
        if self.state is not None:
            # Module was already active on entry: put back the captured values.
            for p, s in self.state:
                p._value = s
        else:
            # This context added the "active" memo, so it removes it.
            self.module.remove_memo("active")


class ValidContext:
    """
    Context manager that switches on a module's ``valid_context`` flag.

    Only inside a ValidContext will parameters automatically be assumed
    valid. The flag's previous value is captured on entry and written back
    on exit. Whatever setting was in force before the ``with`` block is
    therefore in force again afterwards.
    """

    def __init__(self, module: Module):
        self.module = module

    def __enter__(self):
        mod = self.module
        # Read the current flag first (right-hand side is evaluated before
        # any assignment), then force it on.
        self.init_valid, mod.valid_context = mod.valid_context, True

    def __exit__(self, exc_type, exc_value, traceback):
        # Write back the captured flag. Returning None lets any exception
        # raised inside the with-body propagate unchanged.
        mod = self.module
        mod.valid_context = self.init_valid
class OverrideParam:
    """
    Context manager to override a parameter value. Only inside an
    OverrideParam will the parameter be set to the new value.

    On entry, the ``_value`` of ``param`` and of every pointer ``Param``
    reachable from it through ``.parents`` is recorded. Each such pointer's
    ``_value`` is then set to None, and ``param._value`` is replaced by
    ``value``. On exit, every recorded ``_value`` is written back.
    """

    def __init__(self, param: Param, value):
        self.param = param
        self.value = value

    @staticmethod
    def _collect_old_values(param, _seen=None):
        """
        Recursively collect the old values for any pointer affected by the override.

        Parameters
        ----------
        param
            The node whose ``_value`` is recorded first.
        _seen
            Internal: set of ``id()``s of nodes already recorded during this
            traversal. It lets a pointer reachable along several parent paths
            be recorded and cleared only once. Without it, a second visit
            would record the already-cleared ``None`` and the restore could
            leave the pointer at ``None``.

        Returns
        -------
        list of (node, old_value) tuples, in recording order.
        """
        if _seen is None:
            _seen = set()
        # id() avoids relying on Param's __eq__/__hash__ semantics.
        _seen.add(id(param))
        old_values = [(param, param._value)]
        for node in param.parents:
            if isinstance(node, Param) and node.pointer and id(node) not in _seen:
                old_values.extend(OverrideParam._collect_old_values(node, _seen))
                # Cleared only after its own value has been recorded above.
                node._value = None
        return old_values

    def __enter__(self):
        # Store the old value(s) of the parameter and any pointers that may need updating
        self.old_values = OverrideParam._collect_old_values(self.param)
        # Set the new value
        self.param._value = self.value

    def __exit__(self, exc_type, exc_value, traceback):
        # Reset the param and pointer values as they were before the override.
        # Replaying in reverse recording order (undo-log order) guarantees the
        # earliest, i.e. original, value of any node is the one that sticks.
        for node, value in reversed(self.old_values):
            node._value = value