Source code for cooper.penalty_coefficients.penalty_coefficients

# Copyright (C) 2025 The Cooper Developers.
# Licensed under the MIT License.

import abc
from typing import Any, Optional

import torch
from typing_extensions import Self


[docs] class PenaltyCoefficient(abc.ABC): """Abstract class for constant (non-trainable) penalty coefficients. Args: init: Value of the penalty coefficient. Raises: ValueError: If ``init`` has two or more dimensions. """ expects_constraint_features: bool _value: Optional[torch.Tensor] = None def __init__(self, init: torch.Tensor) -> None: if init.dim() > 1: raise ValueError("init must either be a scalar or a 1D tensor of shape `(num_constraints,)`.") self.init = init.clone() self.value = init @property def value(self) -> torch.Tensor: """Return the current value of the penalty coefficient.""" return self._value @value.setter def value(self, value: torch.Tensor) -> None: """Update the value of the penalty. Raises: ValueError: if the provided ``value`` has a different shape than the existing one or contains negative entries. """ if value.requires_grad: raise ValueError("PenaltyCoefficient should not require gradients.") if self._value is not None and value.shape != self._value.shape: raise ValueError( f"New shape {value.shape} of PenaltyCoefficient does not match existing shape {self._value.shape}." ) self._value = value.clone() self.sanity_check()
[docs] def to(self, *args: Any, **kwargs: Any) -> Self: """Move the penalty coefficient to a new ``device`` and/or change its ``dtype``. """ self._value = self._value.to(*args, **kwargs) return self
[docs] def state_dict(self) -> dict: """Return the current state of the penalty coefficient.""" return {"value": self._value}
[docs] def load_state_dict(self, state_dict: dict) -> None: """Load the state of the penalty coefficient. Args: state_dict: Dictionary containing the state of the penalty coefficient. """ self._value = state_dict["value"]
[docs] def sanity_check(self) -> None: """Check that the penalty coefficient is well-formed. Raises: ValueError: If the penalty coefficient contains negative entries. """ if torch.any(self._value < 0): raise ValueError("All entries of the penalty coefficient must be non-negative.")
def __repr__(self) -> str: if self.value.numel() <= 10: # noqa: PLR2004 return f"{type(self).__name__}({self.value})" return f"{type(self).__name__}(shape={self.value.shape})"
[docs] @abc.abstractmethod def __call__(self, *args: Any, **kwargs: Any) -> torch.Tensor: """Return the current value of the penalty coefficient."""
[docs] class DensePenaltyCoefficient(PenaltyCoefficient): """Constant (non-trainable) coefficient class used for Augmented Lagrangian formulation.""" expects_constraint_features = False
[docs] @torch.no_grad() def __call__(self) -> torch.Tensor: """Return the current value of the penalty coefficient.""" return self.value.clone()
[docs] class IndexedPenaltyCoefficient(PenaltyCoefficient): """Constant (non-trainable) penalty coefficients. When called, indexed penalty coefficients accept a tensor of indices and return the value of the penalty for a subset of constraints. """ expects_constraint_features = True
[docs] @torch.no_grad() def __call__(self, indices: torch.Tensor) -> torch.Tensor: """Return the current value of the penalty coefficient at the provided indices. Args: indices: Tensor of indices for which to return the penalty coefficient. Raises: ValueError: If ``indices`` is not of type ``torch.long``. """ if indices.dtype != torch.long: # Not allowing for boolean "indices", which are treated as indices by # torch.nn.functional.embedding and *not* as masks. raise ValueError("Indices must be of type torch.long.") if self.value.dim() == 0: return self.value.clone() coefficient_values = torch.nn.functional.embedding(indices, self.value.unsqueeze(1), sparse=False) # Flatten coefficient values to 1D since Embedding works with 2D tensors. return torch.flatten(coefficient_values)