Source code for caskade.context
from .module import Module
from .param import Param
from .errors import ActiveStateError
class ActiveContext:
    """
    Context manager that activates a module for a simulation.

    Only inside an ActiveContext is it possible to fill/clear the dynamic and
    live parameters.

    On entry the module is marked with the ``f"{name}_active"`` memo, plus the
    ``"active"`` memo if the module was not already active. If it *was* already
    active, the current ``_value`` of every parameter in ``all_params`` is
    snapshotted. On exit the module state is cleared, the per-module memo is
    removed, and either the snapshot is written back or the ``"active"`` memo
    is removed.

    Raises
    ------
    ActiveStateError
        On entry, if ``module.online`` is already truthy.
    """

    def __init__(self, module: Module):
        self.module = module

    def __enter__(self):
        mod = self.module
        if mod.online:
            raise ActiveStateError(f"Module '{mod.name}' is already running a simulation")

        if not mod.active:
            # Nothing to preserve: this context is responsible for activating
            # the module, and __exit__ will deactivate it again.
            self.state = None
            mod.add_memo("active")
        else:
            # Already active: remember every parameter value so __exit__ can
            # put them back after clear_state() wipes them.
            self.state = [param._value for param in mod.all_params]

        mod.add_memo(f"{mod.name}_active")

    def __exit__(self, exc_type, exc_value, traceback):
        mod = self.module
        mod.clear_state()
        mod.remove_memo(f"{mod.name}_active")

        if self.state is None:
            # __enter__ activated the module, so undo that activation.
            mod.remove_memo("active")
            return

        # The module was active before entry: restore the snapshot.
        for param, saved in zip(mod.all_params, self.state):
            param._value = saved
[docs]
class ValidContext:
    """
    Context manager that switches a module's ``valid_context`` flag on.

    Only inside a ValidContext will parameters automatically be assumed
    valid. The flag value seen on entry is remembered and written back on
    exit, so nesting inside an already-valid context leaves it set.
    """

    def __init__(self, module: Module):
        self.module = module

    def __enter__(self):
        target = self.module
        # Save the previous flag and force it on in a single step.
        self.init_valid, target.valid_context = target.valid_context, True

    def __exit__(self, exc_type, exc_value, traceback):
        # Put back whatever the flag was before __enter__.
        target = self.module
        target.valid_context = self.init_valid
[docs]
class OverrideParam:
    """
    Context manager to temporarily override a parameter value.

    Only inside an OverrideParam will the parameter be set to the new value.

    On entry the following happens:
      * ``(node, node._value)`` is recorded for ``param`` and for every
        pointer ``Param`` reachable through ``parents``;
      * each such pointer's ``_value`` is reset to ``None``;
      * ``param._value`` is set to ``value``.
    (Presumably the reset makes pointers re-derive their value from the
    override; confirm against ``Param``.) On exit every recorded node's
    ``_value`` is restored to its pre-override value.

    Parameters
    ----------
    param : Param
        The parameter whose ``_value`` is overridden.
    value
        The temporary value assigned to ``param._value`` inside the context.
    """

    def __init__(self, param: Param, value):
        self.param = param
        self.value = value

    @staticmethod
    def _collect_old_values(param, _seen=None):
        """
        Recursively record old values for ``param`` and affected pointers.

        Records ``(param, param._value)``. Then, for each parent that is a
        pointer ``Param``, recurses into it and sets its ``_value`` to ``None``.

        Parameters
        ----------
        param : Param
            Node to record.
        _seen : set of int, optional
            Internal. The ``id()`` of every node already recorded in this walk.

        Returns
        -------
        list of tuple
            ``(node, old_value)`` pairs, each node appearing exactly once.
        """
        # Dedupe by identity. Without it, a pointer reachable through two
        # paths (a diamond) is recorded a second time *after* the first visit
        # has set it to None. Forward-order restoration then leaves it None.
        # It would also re-walk the shared subgraph once per path.
        if _seen is None:
            _seen = set()
        _seen.add(id(param))

        old_values = [(param, param._value)]
        for node in param.parents:
            if isinstance(node, Param) and node.pointer and id(node) not in _seen:
                old_values += OverrideParam._collect_old_values(node, _seen)
                # Clear the pointer's value so it is not stale during the override.
                node._value = None
        return old_values

    def __enter__(self):
        # Store the old value(s) of the parameter and any pointers that may need updating
        self.old_values = OverrideParam._collect_old_values(self.param)
        # Set the new value
        self.param._value = self.value

    def __exit__(self, exc_type, exc_value, traceback):
        # Reset the param and pointer values as they were before the override.
        # Each node appears once in old_values, so every node gets back its
        # genuine pre-override value.
        for node, value in self.old_values:
            node._value = value