# Source code for cooper.cmp

# Copyright (C) 2025 The Cooper Developers.
# Licensed under the MIT License.

import abc
from collections import OrderedDict
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Any, Literal, Optional

import torch
from typing_extensions import Self

from cooper.constraints import Constraint, ConstraintState
from cooper.multipliers import Multiplier
from cooper.penalty_coefficients import PenaltyCoefficient

# Public names of this module: the only names bound by ``from cooper.cmp import *``.
__all__ = [
    "CMPState",
    "ConstrainedMinimizationProblem",
    "LagrangianStore",
]


@dataclass
class LagrangianStore:
    """Stores the value of the (primal or dual) Lagrangian, as well as the multiplier
    and penalty coefficient values for the *observed* constraints.

    Args:
        lagrangian: Value of the Lagrangian. ``None`` when there is nothing to
            differentiate.
        multiplier_values: Value of the multipliers associated with the observed
            constraints, keyed by constraint.
        penalty_coefficient_values: Value of the penalty coefficients associated with
            the observed constraints, keyed by constraint.
    """

    lagrangian: Optional[torch.Tensor] = None
    multiplier_values: dict[Constraint, torch.Tensor] = field(default_factory=dict)
    penalty_coefficient_values: dict[Constraint, torch.Tensor] = field(default_factory=dict)

    def backward(self) -> None:
        """Back-propagates through the stored Lagrangian.

        This is a no-op when ``lagrangian`` is ``None``.
        """
        if self.lagrangian is not None:
            self.lagrangian.backward()

    def observed_multiplier_values(self) -> Iterator[torch.Tensor]:
        """Returns an iterator over the stored multiplier values."""
        yield from self.multiplier_values.values()

    def observed_penalty_coefficient_values(self) -> Iterator[torch.Tensor]:
        """Returns an iterator over the stored penalty coefficient values."""
        yield from self.penalty_coefficient_values.values()
@dataclass
class CMPState:
    r"""Represents the state of a :py:class:`~.ConstrainedMinimizationProblem` in terms
    of the value of its loss and constraint violations at point :math:`\vx_t`.

    Args:
        loss: Value of the loss or main objective to be minimized :math:`f(\vx_t)`.
        observed_constraints: Dictionary with :py:class:`~Constraint` instances as keys
            and :py:class:`~ConstraintState` instances as values (containing tensors
            :math:`\vg(\vx_t)` and :math:`\vh(\vx_t)`).
        misc: Optional storage space for additional information relevant to the state
            of the CMP. This dictionary enables persisting the results of certain
            computations for post-processing. For example, one may want to retain the
            value of the predictions/logits computed over a given minibatch during the
            call to :py:meth:`~.ConstrainedMinimizationProblem.compute_cmp_state` to
            measure or log training statistics.
    """

    loss: Optional[torch.Tensor] = None
    observed_constraints: dict[Constraint, ConstraintState] = field(default_factory=dict)
    misc: Optional[dict] = None

    def _compute_primal_or_dual_lagrangian(self, primal_or_dual: Literal["primal", "dual"]) -> LagrangianStore:
        """Computes the primal or dual Lagrangian based on the loss and the contribution
        of the observed constraints.

        We don't count the loss towards the dual Lagrangian since the objective is not
        a function of the dual variables.

        Args:
            primal_or_dual: Which Lagrangian to build. Selects the
                ``contributes_to_{primal_or_dual}_update`` flag read from each
                constraint state and is forwarded to
                ``Constraint.compute_contribution_to_lagrangian``.

        Returns:
            A :py:class:`LagrangianStore` with the accumulated Lagrangian and the
            multiplier / penalty coefficient values reported by the constraints that
            contributed. ``lagrangian`` is ``None`` when neither the loss nor any
            constraint contributed a term.
        """
        contributes_attr = f"contributes_to_{primal_or_dual}_update"
        include_loss = primal_or_dual == "primal" and self.loss is not None

        # Clone so that accumulating constraint terms never aliases ``self.loss``.
        lagrangian = self.loss.clone() if include_loss else 0.0
        has_terms = include_loss

        multiplier_values = {}
        penalty_coefficient_values = {}
        for constraint, constraint_state in self.observed_constraints.items():
            if not getattr(constraint_state, contributes_attr):
                continue

            contribution_store = constraint.compute_contribution_to_lagrangian(constraint_state, primal_or_dual)
            if contribution_store is None:
                continue

            lagrangian = lagrangian + contribution_store.lagrangian_contribution
            has_terms = True
            multiplier_values[constraint] = contribution_store.multiplier_value
            if contribution_store.penalty_coefficient_value is not None:
                penalty_coefficient_values[constraint] = contribution_store.penalty_coefficient_value

        if not has_terms:
            # Nothing was accumulated: report "no Lagrangian" instead of leaking the
            # float placeholder, which has no ``backward`` method.
            lagrangian = None

        return LagrangianStore(
            lagrangian=lagrangian,
            multiplier_values=multiplier_values,
            penalty_coefficient_values=penalty_coefficient_values,
        )

    def compute_primal_lagrangian(self) -> LagrangianStore:
        """Computes and accumulates the primal-differentiable Lagrangian based on the
        loss and the contribution of the observed constraints.
        """
        return self._compute_primal_or_dual_lagrangian(primal_or_dual="primal")

    def compute_dual_lagrangian(self) -> LagrangianStore:
        """Computes and accumulates the dual-differentiable Lagrangian based on the
        contribution of the observed constraints.

        The dual Lagrangian contained in ``LagrangianStore.lagrangian`` ignores the
        contribution of the loss, since the objective function does not depend on the
        dual variables. Therefore, the value of ``self.loss`` never affects the result.
        """
        return self._compute_primal_or_dual_lagrangian(primal_or_dual="dual")

    def named_observed_violations(self) -> Iterator[tuple[str, torch.Tensor]]:
        """Returns an iterator over the observed constraint violations."""
        for constraint, constraint_state in self.observed_constraints.items():
            yield constraint.name, constraint_state.violation

    def named_observed_strict_violations(self) -> Iterator[tuple[str, torch.Tensor]]:
        """Returns an iterator over the observed strict constraint violations."""
        for constraint, constraint_state in self.observed_constraints.items():
            yield constraint.name, constraint_state.strict_violation

    def named_observed_constraint_features(self) -> Iterator[tuple[str, torch.Tensor]]:
        """Returns an iterator over the observed constraint features."""
        for constraint, constraint_state in self.observed_constraints.items():
            yield constraint.name, constraint_state.constraint_features

    def named_observed_strict_constraint_features(self) -> Iterator[tuple[str, torch.Tensor]]:
        """Returns an iterator over the observed strict constraint features."""
        for constraint, constraint_state in self.observed_constraints.items():
            yield constraint.name, constraint_state.strict_constraint_features
class ConstrainedMinimizationProblem(abc.ABC):
    """Template for constrained minimization problems, where subclasses represent
    specific constrained optimization problems.

    Subclasses must override the
    :py:meth:`CMP.compute_cmp_state<.ConstrainedMinimizationProblem.compute_cmp_state>`
    method. This method should return a :py:class:`~.CMPState` instance that
    encapsulates the current state of the optimization problem, including the
    evaluated loss and the values of the constraint violations.

    Assigning a :py:class:`~Constraint` to an attribute of the CMP registers it under
    the attribute's name; registered constraints are then readable as attributes.
    """

    def __init__(self) -> None:
        self._constraints = OrderedDict()

    def _register_constraint(self, name: str, constraint: Constraint) -> None:
        """Registers a constraint with the CMP and sets ``constraint.name`` to ``name``.

        Args:
            name: Name of the constraint.
            constraint: Constraint instance to be registered.

        Raises:
            TypeError: If attribute value is not a constraint.
            ValueError: If constraint with `name` already exists.
        """
        if not isinstance(constraint, Constraint):
            raise TypeError(f"Expected a Constraint instance, got {type(constraint)}")
        if name in self._constraints:
            # Allowing for constraint value changes could alter operation of the
            # optimizers. Users would need to re-build the optimizer to ensure the
            # multipliers for the new constraint are accessible to the optimizer.
            raise ValueError(f"Constraint with name {name} already exists")

        self._constraints[name] = constraint
        constraint.name = name

    def constraints(self) -> Iterator[Constraint]:
        """Return an iterator over the registered constraints of the CMP."""
        yield from self._constraints.values()

    def named_constraints(self) -> Iterator[tuple[str, Constraint]]:
        """Return an iterator over the registered constraints of the CMP, yielding
        tuples of the form ``(constraint_name, constraint)``.
        """
        yield from self._constraints.items()

    def multipliers(self) -> Iterator[Multiplier]:
        """Returns an iterator over the multipliers associated with the registered
        constraints of the CMP. Constraints without a multiplier are skipped.
        """
        for constraint in self.constraints():
            if constraint.multiplier is not None:
                yield constraint.multiplier

    def named_multipliers(self) -> Iterator[tuple[str, Multiplier]]:
        """Returns an iterator over the multipliers associated with the registered
        constraints of the CMP, yielding tuples of the form
        ``(constraint_name, multiplier)``. Constraints without a multiplier are skipped.
        """
        for constraint_name, constraint in self.named_constraints():
            if constraint.multiplier is not None:
                yield constraint_name, constraint.multiplier

    def penalty_coefficients(self) -> Iterator[PenaltyCoefficient]:
        """Returns an iterator over the penalty coefficients associated with the
        registered constraints of the CMP. Constraints without penalty coefficients
        are skipped.
        """
        for constraint in self.constraints():
            if constraint.penalty_coefficient is not None:
                yield constraint.penalty_coefficient

    def named_penalty_coefficients(self) -> Iterator[tuple[str, PenaltyCoefficient]]:
        """Returns an iterator over the penalty coefficients associated with the
        registered constraints of the CMP, yielding tuples of the form
        ``(constraint_name, penalty_coefficient)``. Constraints without penalty
        coefficients are skipped.
        """
        for constraint_name, constraint in self.named_constraints():
            if constraint.penalty_coefficient is not None:
                yield constraint_name, constraint.penalty_coefficient

    def dual_parameters(self) -> Iterator[torch.nn.Parameter]:
        """Return an iterator over the parameters of the multipliers associated with
        the registered constraints of the CMP. This method is useful for instantiating
        the dual optimizers. If a multiplier is shared by several constraints, we only
        return its parameters once.

        Multipliers are visited in constraint registration order, so the resulting
        parameter order is the same on every run.
        """
        # ``dict.fromkeys`` removes duplicates like ``set`` does, but keeps first-seen
        # order; iterating a set would yield a hash-dependent order.
        for multiplier in dict.fromkeys(self.multipliers()):
            yield from multiplier.parameters()

    def to(self, *args: Any, **kwargs: Any) -> Self:
        """Move the CMP to a new device and/or change the dtype of the multipliers and
        penalty coefficients.

        All arguments are forwarded to the ``to`` method of each multiplier and penalty
        coefficient.

        Returns:
            The CMP itself.
        """
        for constraint in self.constraints():
            if constraint.multiplier is not None:
                constraint.multiplier = constraint.multiplier.to(*args, **kwargs)
            if constraint.penalty_coefficient is not None:
                constraint.penalty_coefficient = constraint.penalty_coefficient.to(*args, **kwargs)
        return self

    def state_dict(self) -> dict:
        """Returns the state of the CMP. This includes the state of the multipliers and
        penalty coefficients, each keyed by constraint name.
        """
        return {
            "multipliers": {name: multiplier.state_dict() for name, multiplier in self.named_multipliers()},
            "penalty_coefficients": {name: pc.state_dict() for name, pc in self.named_penalty_coefficients()},
        }

    def load_state_dict(self, state_dict: dict) -> None:
        """Loads the state of the CMP. This includes the state of the multipliers and
        penalty coefficients.

        Args:
            state_dict: A state dictionary containing the state of the CMP, with
                ``"multipliers"`` and ``"penalty_coefficients"`` entries keyed by
                constraint name.

        Raises:
            KeyError: If ``state_dict`` names a constraint that is not registered.
        """
        for name, multiplier_state_dict in state_dict["multipliers"].items():
            multiplier = self._constraints[name].multiplier
            multiplier.load_state_dict(multiplier_state_dict)
            multiplier.sanity_check()

        for name, penalty_coefficient_state_dict in state_dict["penalty_coefficients"].items():
            penalty_coefficient = self._constraints[name].penalty_coefficient
            penalty_coefficient.load_state_dict(penalty_coefficient_state_dict)
            penalty_coefficient.sanity_check()

    def __setattr__(self, name: str, value: Any) -> None:
        # Constraints are stored in the registry rather than in the instance dict.
        if isinstance(value, Constraint):
            self._register_constraint(name, value)
        else:
            super().__setattr__(name, value)

    def __getattr__(self, name: str) -> Any:
        # Only invoked when regular attribute lookup fails. The registry is read through
        # ``__dict__`` because ``self._constraints`` would re-enter ``__getattr__`` and
        # recurse without bound whenever the registry has not been created yet (e.g.
        # ``__init__`` was not run, or ``copy``/``pickle`` probe a blank instance).
        constraints = self.__dict__.get("_constraints")
        if constraints is not None and name in constraints:
            return constraints[name]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __delattr__(self, name: str) -> None:
        if name in self._constraints:
            del self._constraints[name]
        else:
            super().__delattr__(name)

    def __repr__(self) -> str:
        repr_str = type(self).__name__
        # Constraints are only listed for small problems; otherwise just the class name.
        if len(self._constraints) < 5:  # noqa: PLR2004
            entries = [f"\t\t{name}: {constraint}" for name, constraint in self.named_constraints()]
            listing = ",\n".join(entries)
            if entries:
                listing += "\n"
            repr_str += f"\n\t(constraints=[\n{listing}\t\t]\n\t)"
        return repr_str

    @abc.abstractmethod
    def compute_cmp_state(self, *args: Any, **kwargs: Any) -> CMPState:
        """Computes the state of the CMP based on the current value of the primal
        parameters.

        The signature of this function may be adjusted to accommodate situations that
        require a model, (mini-batched) inputs/targets, or other arguments to be passed.

        .. note::
            When it is prohibitively expensive to compute the loss or constraints
            exactly, the :py:class:`CMPState` may contain **stochastic estimates**.
            This is often the case when mini-batches are used to approximate the loss
            and constraints. Just as in the unconstrained case, these approximations
            can lead to a compromise in the stability of the optimization process.
        """

    @staticmethod
    def sanity_check_cmp_state(cmp_state: CMPState) -> None:
        """Performs sanity checks on the CMP state.

        This helper method is useful for ensuring that the CMP state is well-formed.
        Differentiability is judged by ``requires_grad``: the ``.grad`` attribute is
        only populated after a backward pass (and only on leaf tensors by default), so
        it cannot tell whether a freshly computed tensor is part of the autograd graph.

        Args:
            cmp_state: The state to validate.

        Raises:
            ValueError: If the loss tensor is not differentiable.
            ValueError: If the violation tensor of any constraint is not
                differentiable.
            ValueError: If the strict violation tensor of any constraint is
                differentiable.
            ValueError: If a constraint contributes to the dual update but the
                associated formulation does not expect a multiplier.
        """
        if cmp_state.loss is not None and not cmp_state.loss.requires_grad:
            raise ValueError("The loss tensor must have a valid gradient.")

        for constraint, constraint_state in cmp_state.observed_constraints.items():
            if not constraint_state.violation.requires_grad:
                raise ValueError(f"The violation tensor of constraint {constraint} must have a valid gradient.")
            if constraint_state.strict_violation is not None and constraint_state.strict_violation.requires_grad:
                raise ValueError(f"The strict violation tensor of constraint {constraint} must not have a gradient.")
            if not constraint.formulation.expects_multiplier and constraint_state.contributes_to_dual_update:
                raise ValueError(
                    f"ConstraintState contributes to dual update but formulation {constraint.formulation} "
                    f"associated with constraint {constraint} does not expect a multiplier."
                )

    def compute_violations(self, *args: Any, **kwargs: Any) -> CMPState:
        """Computes the violation of the CMP constraints based on the current value of
        the primal parameters.

        This function returns a :py:class:`~.CMPState` instance containing the observed
        constraint values. Note that the returned :py:class:`~.CMPState` may have
        ``loss=None``, as the loss value is not necessarily computed when only
        evaluating the constraints.

        The function signature may be adjusted to accommodate situations that require a
        model, (mini-batched) inputs/targets, or other arguments.

        In some cases, the computation of constraints may be independent of loss
        evaluation. In such situations,
        :py:meth:`CMP.compute_violations<.ConstrainedMinimizationProblem.compute_violations>`
        can be called as part of the execution of
        :py:meth:`CMP.compute_cmp_state<.ConstrainedMinimizationProblem.compute_cmp_state>`.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError