"""
Base classes defining the parameter tree.
"""
from __future__ import annotations
import abc
import contextlib
import copy
import re
from typing import Any, ClassVar, Generator, Generic, Mapping, TypeVar, Union, overload
import numpy as np
from typing_extensions import Literal, Self, TypeAlias
from .util import MetadataDict, ParameterError, UnsetParameterError
# Path to a node in the parameter tree: either a tuple of names or a single
# string, which ``_parse_parameter_path`` splits on "." or "__".
ParameterPath: TypeAlias = Union[str, tuple[str, ...]]
# Type of the instance hosting a descriptor (see ``Field.__get__``).
T = TypeVar("T")
# Type of the value stored by a ``Field``.
F = TypeVar("F")
# Type of the ``Parameter`` subclass accepted by a ``ParameterField``.
P = TypeVar("P", bound="Parameter")
# Module-wide switch read by ``Field.__set__`` to decide whether ``validate``
# is called; temporarily overridden by the ``validate_fields`` context manager.
_VALIDATE_FIELDS: bool = True
[docs]
class Field(Generic[F]):
    """
    Base class for field descriptors.

    Fields are leaves in the input parameter tree that are responsible for storing
    fixed input parameter values.

    Each instance of a field may optionally specify a ``title`` and ``unit``, which
    help document the meaning of the field and can be used in human-readable
    representations.

    Accessing a field through the host class (rather than through an instance)
    returns the descriptor object itself.
    """

    public_name: str
    """Name of this descriptor in the host class."""
    private_name: str
    """Attribute used to store this field's value in the host instance."""
    title: str | None
    """Human readable title of this field."""
    unit: str | None
    """Unit that this field is measured in."""
    dtype: ClassVar[np.dtype[Any]]
    """Numpy dtype used to store sweeps of this field."""

    def __init__(self, title: str | None = None, unit: str | None = None):
        """
        Create a field descriptor.

        Parameters
        ----------
        title
            Optional human readable title of the field.
        unit
            Optional unit that the field is measured in.
        """
        self.title = title
        self.unit = unit

    def __set_name__(self, owner: Any, name: str) -> None:
        """Record the attribute name this descriptor was bound to in ``owner``."""
        self.public_name = name
        # Values live on the host instance under an underscore-prefixed name.
        self.private_name = f"_{name}"

    @overload
    def __get__(self, instance: None, owner: type[Any] | None = ...) -> Self:
        ...

    @overload
    def __get__(self, instance: T, owner: type[T] | None = ...) -> F:
        ...

    def __get__(self, instance: Any, owner: Any = None) -> Any:
        """
        Return the value stored on ``instance`` for this field.

        When accessed on the host class (``instance`` is ``None``) the
        descriptor itself is returned, following the standard descriptor
        convention, instead of reporting a misleading "unset" error.

        Raises
        ------
        UnsetParameterError
            If no value has been stored on ``instance`` for this field yet.
        """
        if instance is None:
            return self
        try:
            return getattr(instance, self.private_name)
        except AttributeError:
            # The missing private attribute is an implementation detail; hide it
            # from the traceback so only the "unset parameter" error is shown.
            raise UnsetParameterError(
                f"attempted to access unset parameter"
                f" {type(instance).__name__}.{self.public_name}"
            ) from None

    def __set__(self, instance: Any, value: F) -> None:
        """
        Store ``value`` on ``instance``.

        ``validate`` is called first unless validation has been switched off
        through the module-level ``_VALIDATE_FIELDS`` flag.
        """
        if _VALIDATE_FIELDS:
            self.validate(instance, value)
        setattr(instance, self.private_name, value)

    def validate(self, instance: Any, value: F) -> None:
        """Validation hook called by ``__set__``; the base implementation is a no-op."""
        pass  # TODO settle on validation framework

    def metadata(self) -> MetadataDict:
        """
        Return this field's metadata.

        The result contains a ``"title"`` and/or ``"unit"`` entry only for those
        attributes that are not ``None``.
        """
        metadata: MetadataDict = {}
        if self.title is not None:
            metadata["title"] = self.title
        if self.unit is not None:
            metadata["unit"] = self.unit
        return metadata
[docs]
class Parameter(metaclass=ParameterMeta):
    """
    Base class for input parameters.

    Parameter subclasses should represent some definite physical representation
    of a model parameter that RTM engines can optionally implement.
    """

    # Names of the fields available on this parameter. Never assigned in this
    # class body - presumably populated by ``ParameterMeta`` (TODO confirm).
    _fields: ClassVar[frozenset[str]]

    def __init__(self, **kwargs: Any) -> None:
        """
        Initialise the parameter from keyword arguments.

        Each keyword must be a member of ``_fields`` and is assigned with
        ``setattr``. Fields that are not passed are left untouched.

        Raises
        ------
        ParameterError
            If a keyword is not a member of ``_fields``.
        """
        for name, value in kwargs.items():
            if name in self._fields:
                setattr(self, name, value)
            else:
                raise ParameterError(
                    f"unknown field '{name}' - must be one of {list(self._fields)}"
                )

    def __repr__(self) -> str:
        """Return ``ClassName(field=value, ...)`` with fields sorted by name.

        Fields whose access raises ``UnsetParameterError`` are shown as
        ``<UNSET>``.
        """
        field_parts = []
        for name in sorted(self._fields):
            try:
                field_parts.append(f"{name}={getattr(self, name)!r}")
            except UnsetParameterError:
                field_parts.append(f"{name}=<UNSET>")
        return f"{type(self).__name__}({', '.join(field_parts)})"

    def replace(self, *args: Any, **kwargs: Any) -> Self:
        """
        Return a deep copy of this parameter with updates applied.

        The arguments are forwarded unchanged to ``set`` on the copy; this
        instance is not modified.
        """
        duplicate = copy.deepcopy(self)
        duplicate.set(*args, **kwargs)
        return duplicate

    @overload
    def set(self, param: ParameterPath, value: Any, /) -> None:
        ...

    @overload
    def set(self, updates: Mapping[str, Any], /) -> None:
        ...

    @overload
    def set(
        self,
        /,
        **kwargs: Any,
    ) -> None:
        ...

    def set(
        self,
        param: Mapping[str, Any] | ParameterPath | None = None,
        value: Any | None = None,
        /,
        **kwargs: Any,
    ) -> None:
        """
        Set one or more fields in place.

        Supported call forms:

        * ``set(path, value)`` - set the single field addressed by ``path``.
        * ``set(mapping)`` - set every ``path: value`` pair of the mapping.
        * ``set(**kwargs)`` - set every ``path=value`` pair.
        * ``set()`` - do nothing.

        String paths are split on ``.`` or ``__``; tuple paths are used as given.

        Raises
        ------
        ValueError
            If keyword arguments are combined with positional arguments.
        ParameterError
            If ``value`` is ``None`` for a non-mapping ``param``, or if setting
            the field fails for any reason (the original error is chained).
        """
        if kwargs:
            if param is not None:
                raise ValueError(
                    "kwargs must not be passed when positional arguments are used"
                )
            for param_path, param_arg in kwargs.items():
                self.set(param_path, param_arg)
            return
        if param is None:
            # No positional arguments were given.
            # Called with no arguments - do nothing.
            return
        if isinstance(param, Mapping):
            for param_path, param_arg in param.items():
                self.set(param_path, param_arg)
            return
        if value is None:
            raise ParameterError(
                "value must be specified when first argument is not a mapping"
            )
        try:
            self._set(_parse_parameter_path(param), value)
        except Exception as ex:
            raise ParameterError(
                f"failed to set field '{param}' to '{value}': {ex}"
            ) from ex

    def _set(self, path: tuple[str, ...], value: Any) -> None:
        """
        Set the field addressed by the already parsed ``path``.

        The first path element is resolved on this parameter; any remaining
        elements are delegated to the sub-parameter stored under that name.
        """
        curr_field, *sub_path = path
        if sub_path:
            try:
                sub_param = getattr(self, curr_field)
            except AttributeError:
                raise ParameterError(f"unknown parameter '{curr_field}'") from None
            # Star-unpacking yields a list; convert back to the annotated tuple.
            sub_param._set(tuple(sub_path), value)
        else:
            # Note: can't use hasattr check, since it would invoke __get__ on
            # <UNSET> fields.
            if curr_field not in vars(self.__class__):
                raise ParameterError(f"unknown field '{curr_field}'")
            setattr(self, curr_field, value)

    @overload
    def get_fields(self, style: Literal[".", "__"] = ...) -> list[str]:
        ...

    @overload
    def get_fields(self, style: Literal["()"]) -> list[tuple[str, ...]]:
        ...

    def get_fields(
        self, style: Literal[".", "__", "()"] = "."
    ) -> list[str] | list[tuple[str, ...]]:
        """
        Return list containing the paths to all this parameter's terminal fields.

        Parameters
        ----------
        style
            ``"."`` or ``"__"`` joins each path into a string with that
            separator; ``"()"`` returns each path as a tuple of names.

        Raises
        ------
        ValueError
            If ``style`` is not one of the supported values.

        Notes
        -----
        ``ParameterField`` members are expanded recursively by reading the
        current sub-parameter from this instance, so they must be set.
        The result follows the iteration order of ``_fields``.
        """
        if style not in (".", "__", "()"):
            raise ValueError(f"unknown parameter path style '{style}'")
        paths = []
        for field_name in self._fields:
            # Look the descriptor up in the class namespace so that its
            # ``__get__`` is not triggered.
            field = vars(self.__class__)[field_name]
            if isinstance(field, ParameterField):
                for p in getattr(self, field_name).get_fields("()"):
                    paths.append((field_name,) + p)
            else:
                paths.append((field_name,))
        if style != "()":
            paths = [style.join(p) for p in paths]
        return paths

    def get_metadata(self, param: ParameterPath) -> MetadataDict:
        """
        Return the metadata of the field addressed by ``param``.

        String paths are split on ``.`` or ``__``; tuple paths are used as given.

        Raises
        ------
        ParameterError
            If a path element does not name a known parameter or field.
        """
        return self._get_metadata(_parse_parameter_path(param))

    def _get_metadata(self, path: tuple[str, ...]) -> MetadataDict:
        """
        Return the metadata of the field addressed by the parsed ``path``.

        The first path element is resolved on this parameter; any remaining
        elements are delegated to the sub-parameter stored under that name.
        """
        curr_field, *sub_path = path
        if sub_path:
            try:
                sub_param = getattr(self, curr_field)
            except AttributeError:
                raise ParameterError(f"unknown parameter '{curr_field}'") from None
            # Star-unpacking yields a list; convert back to the annotated tuple.
            return sub_param._get_metadata(tuple(sub_path))
        else:
            try:
                return vars(self.__class__)[curr_field].metadata()
            except KeyError:
                raise ParameterError(f"unknown field '{curr_field}'") from None
[docs]
class AbstractParameter(Parameter, metaclass=AbstractParameterMeta):
    """
    Base class for abstract input parameters.

    Like ``Parameter``, but supports ``abc.ABC`` abstract behavior.
    """
def _parse_parameter_path(param_path: str | tuple[str, ...]) -> tuple[str, ...]:
if isinstance(param_path, tuple):
return param_path
return tuple(re.split(r"\.|__", param_path))
[docs]
class ParameterField(Field[P]):
    """
    Field containing a swappable parameter.

    Creates a branch in the parameter tree.
    """

    # Type that every value assigned to this field must be an instance of.
    _parameter_type: type[P]

    def __init__(self, parameter_type: type[P], title: str | None = None) -> None:
        """
        Create a field holding instances of ``parameter_type``.

        The field is created without a unit; ``title`` is passed on to ``Field``.
        """
        self._parameter_type = parameter_type
        super().__init__(title=title, unit=None)

    def validate(self, instance: Any, value: F) -> None:
        """
        Check that ``value`` is an instance of the configured parameter type.

        Raises
        ------
        ParameterError
            If ``value`` is not an instance of the parameter type.
        """
        if isinstance(value, self._parameter_type):
            return
        raise ParameterError(
            f"value for {type(instance).__name__}.{self.public_name}"
            f" must be subclass of {self._parameter_type.__name__},"
            f" got type {type(value).__name__}"
        )
[docs]
class StrField(Field[str]):
    """Field taking on a string value."""

    # Numpy dtype derived from the Python ``str`` type.
    dtype = np.dtype(str)
[docs]
class FloatField(Field[float]):
    """Field taking on a float value."""

    # Numpy dtype derived from the Python ``float`` type.
    dtype = np.dtype(float)
[docs]
class IntField(Field[int]):
    """Field taking on an integer value."""

    # Numpy dtype derived from the Python ``int`` type.
    dtype = np.dtype(int)
[docs]
class FloatArrayField(Field[np.ndarray[Any, Any]]):
    """Field taking on a float array value."""

    # TODO array validation
    # Numpy dtype derived from the Python ``float`` type.
    dtype = np.dtype(float)
[docs]
@contextlib.contextmanager
def validate_fields(flag: bool) -> Generator[None, None, None]:
    """
    Context manager for temporarily enabling or disabling field validation.

    The module-level ``_VALIDATE_FIELDS`` flag is set to ``flag`` for the
    duration of the ``with`` block. The value it had on entry is restored on
    exit, including when the block raises.
    """
    global _VALIDATE_FIELDS
    # Remember the current setting and switch to the requested one in one step.
    saved_flag, _VALIDATE_FIELDS = _VALIDATE_FIELDS, flag
    try:
        yield
    finally:
        _VALIDATE_FIELDS = saved_flag