Source code for cooper.formulations.formulations
# Copyright (C) 2025 The Cooper Developers.
# Licensed under the MIT License.
import abc
from typing import Any, Literal, NamedTuple, Optional
import torch
import cooper.formulations.utils as formulation_utils
from cooper.constraints.constraint_state import ConstraintState
from cooper.multipliers import Multiplier
from cooper.penalty_coefficients import PenaltyCoefficient
from cooper.utils import ConstraintType
class ContributionStore(NamedTuple):
    """Bundle returned by the ``compute_contribution_to_*_lagrangian`` methods.

    It groups the term a constraint adds to a Lagrangian together with the evaluated
    constraint factors (multiplier and/or penalty coefficient) used by the formulation.
    """

    # Term contributed by the constraint to the (primal or dual) Lagrangian.
    lagrangian_contribution: torch.Tensor
    # Evaluated multiplier; left as ``None`` by formulations that do not use a multiplier.
    multiplier_value: Optional[torch.Tensor] = None
    # Evaluated penalty coefficient; left as ``None`` by formulations that do not use one.
    penalty_coefficient_value: Optional[torch.Tensor] = None
[docs]
class Formulation(abc.ABC):
    """Abstract base class describing how a constraint enters the Lagrangian.

    A formulation prescribes the term that a constraint contributes to the
    primal-differentiable and to the dual-differentiable Lagrangians. Equivalently,
    it prescribes how the constraint affects the gradients of the Lagrangian with
    respect to the primal and dual variables.

    Concrete subclasses set the two class-level flags listed below and implement
    :py:meth:`compute_contribution_to_primal_lagrangian` and
    :py:meth:`compute_contribution_to_dual_lagrangian`.

    Attributes:
        expects_multiplier (bool): Whether the formulation requires a multiplier.
        expects_penalty_coefficient (bool): Whether the formulation requires a
            penalty coefficient.
        constraint_type (ConstraintType): The constraint type received at
            construction time.

    Raises:
        ValueError: If the constraint type is not equality or inequality.
    """

    expects_multiplier: bool
    expects_penalty_coefficient: bool

    def __init__(self, constraint_type: ConstraintType) -> None:
        supported_types = {ConstraintType.EQUALITY, ConstraintType.INEQUALITY}
        is_supported = constraint_type in supported_types
        if not is_supported:
            raise ValueError(f"{type(self).__name__} requires either an equality or inequality constraint.")
        self.constraint_type = constraint_type

    def __repr__(self) -> str:
        class_name = type(self).__name__
        return f"{class_name}(constraint_type={self.constraint_type})"

    def sanity_check_multiplier(self, multiplier: Optional[Multiplier]) -> None:
        """Checks that a multiplier is supplied exactly when the formulation expects one.

        Args:
            multiplier: The multiplier module, or ``None`` if there is none.

        Raises:
            ValueError: If a multiplier is expected but missing, or if one is supplied
                to a formulation that does not expect it.
        """
        was_provided = multiplier is not None
        if self.expects_multiplier:
            if not was_provided:
                raise ValueError(f"{type(self).__name__} expects a multiplier but none was provided.")
        elif was_provided:
            raise ValueError(f"Received unexpected multiplier for {type(self).__name__}.")

    def sanity_check_penalty_coefficient(self, penalty_coefficient: Optional[PenaltyCoefficient]) -> None:
        """Checks that a penalty coefficient is supplied exactly when one is expected.

        Args:
            penalty_coefficient: The penalty coefficient module, or ``None`` if there
                is none.

        Raises:
            ValueError: If a penalty coefficient is expected but missing, or if one is
                supplied to a formulation that does not expect it.
        """
        was_provided = penalty_coefficient is not None
        if self.expects_penalty_coefficient:
            if not was_provided:
                raise ValueError(f"{type(self).__name__} expects a penalty coefficient but none was provided.")
        elif was_provided:
            raise ValueError(f"Received unexpected penalty coefficient for {type(self).__name__}.")

    def _prepare_kwargs_for_lagrangian_contribution(
        self,
        constraint_state: ConstraintState,
        multiplier: Optional[Multiplier],
        penalty_coefficient: Optional[PenaltyCoefficient],
        primal_or_dual: Literal["primal", "dual"],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Gathers the inputs needed to compute a constraint's Lagrangian contribution.

        The violations and constraint features are read from the
        :py:class:`cooper.constraints.constraint_state.ConstraintState` through
        ``extract_violations`` and ``extract_constraint_features``. Each call returns a
        pair: the regular quantity first and its "strict" counterpart second. The
        regular quantities are used when ``primal_or_dual`` is ``"primal"``, and the
        strict ones when it is ``"dual"``.

        Note (per the ``ConstraintState`` API, not visible in this module): if no
        ``strict_violation`` was provided, ``extract_violations`` is expected to
        return the same tensor for both entries, with the strict one detached.

        The constraint factors (multiplier and penalty coefficient modules) are then
        evaluated with ``formulation_utils.evaluate_constraint_factor``, using the
        selected constraint features and the shape of the selected violation. A
        factor is only evaluated if the formulation expects it.

        Args:
            constraint_state: The :py:class:`cooper.constraints.constraint_state.ConstraintState`
                object.
            multiplier: The multiplier module. Only used if ``expects_multiplier``.
            penalty_coefficient: The penalty coefficient module. Only used if
                ``expects_penalty_coefficient``.
            primal_or_dual: ``"primal"`` to prepare the inputs for the
                primal-differentiable contribution; ``"dual"`` for the
                dual-differentiable one.

        Returns:
            A tuple ``(violation, multiplier_value, penalty_coefficient_value)``:

            - violation: The regular violation if ``primal_or_dual == "primal"``,
              or the strict violation if ``primal_or_dual == "dual"``.
            - multiplier_value: The evaluated multiplier, or ``None`` if the
              formulation does not expect a multiplier.
            - penalty_coefficient_value: The evaluated penalty coefficient, or
              ``None`` if the formulation does not expect one.
        """
        regular_violation, strict_violation = constraint_state.extract_violations()
        regular_features, strict_features = constraint_state.extract_constraint_features()

        use_strict = primal_or_dual == "dual"
        violation = strict_violation if use_strict else regular_violation
        constraint_features = strict_features if use_strict else regular_features

        # Built eagerly so that the violation's shape is read even if no factor is evaluated.
        shared_kwargs = dict(constraint_features=constraint_features, expand_shape=violation.shape)

        multiplier_value = (
            formulation_utils.evaluate_constraint_factor(module=multiplier, **shared_kwargs)
            if self.expects_multiplier
            else None
        )
        penalty_coefficient_value = (
            formulation_utils.evaluate_constraint_factor(module=penalty_coefficient, **shared_kwargs)
            if self.expects_penalty_coefficient
            else None
        )

        return violation, multiplier_value, penalty_coefficient_value

    @abc.abstractmethod
    def compute_contribution_to_primal_lagrangian(self, *args: Any, **kwargs: Any) -> Optional[ContributionStore]:
        """Computes the contribution of a given constraint violation to the *primal*
        Lagrangian.

        Returns ``None`` if the constraint does not contribute to the primal update
        (i.e., when ``ConstraintState.contributes_to_primal_update=False``).
        """
        raise NotImplementedError

    @abc.abstractmethod
    def compute_contribution_to_dual_lagrangian(self, *args: Any, **kwargs: Any) -> Optional[ContributionStore]:
        """Computes the contribution of a given constraint violation to the *dual*
        Lagrangian.

        Returns ``None`` if the constraint does not contribute to the dual update
        (i.e., when ``ConstraintState.contributes_to_dual_update=False``).
        """
        raise NotImplementedError
[docs]
class Lagrangian(Formulation):
    r"""The Lagrangian formulation implements the following primal Lagrangian:

    .. math::
        \Lag_{\text{primal}}(\vx, \vlambda, \vmu) = f(\vx) + \vlambda^{\top} \tilde{\vg}(\vx) + \vmu^{\top} \tilde{\vh}(\vx).

    And the following dual Lagrangian:

    .. math::
        \Lag_{\text{dual}}(\vx, \vlambda, \vmu) = \vlambda^{\top} \vg(\vx) + \vmu^{\top} \vh(\vx).

    This formulation requires a multiplier and does not accept a penalty coefficient.
    The primal contribution is computed from the constraint state's regular
    violation, and the dual contribution from its strict violation.
    """

    expects_multiplier = True
    expects_penalty_coefficient = False

    def compute_contribution_to_primal_lagrangian(
        self, constraint_state: ConstraintState, multiplier: Multiplier
    ) -> Optional[ContributionStore]:
        """Computes this constraint's contribution to the primal Lagrangian.

        Args:
            constraint_state: State holding the constraint violations.
            multiplier: The multiplier module associated with the constraint.

        Returns:
            A :py:class:`ContributionStore` with the contribution produced by
            ``formulation_utils.compute_primal_weighted_violation`` and the evaluated
            multiplier, or ``None`` if
            ``constraint_state.contributes_to_primal_update`` is false.
        """
        if not constraint_state.contributes_to_primal_update:
            return None

        prepared = self._prepare_kwargs_for_lagrangian_contribution(
            constraint_state=constraint_state,
            multiplier=multiplier,
            penalty_coefficient=None,
            primal_or_dual="primal",
        )
        # The last entry is the penalty coefficient value, always `None` for this formulation.
        violation, multiplier_value, _unused_penalty = prepared

        weighted_violation = formulation_utils.compute_primal_weighted_violation(
            constraint_factor_value=multiplier_value, violation=violation
        )
        return ContributionStore(
            lagrangian_contribution=weighted_violation,
            multiplier_value=multiplier_value,
            penalty_coefficient_value=None,
        )

    def compute_contribution_to_dual_lagrangian(
        self, constraint_state: ConstraintState, multiplier: Multiplier
    ) -> Optional[ContributionStore]:
        """Computes this constraint's contribution to the dual Lagrangian.

        Args:
            constraint_state: State holding the constraint violations.
            multiplier: The multiplier module associated with the constraint.

        Returns:
            A :py:class:`ContributionStore` with the contribution produced by
            ``formulation_utils.compute_dual_weighted_violation`` and the evaluated
            multiplier, or ``None`` if
            ``constraint_state.contributes_to_dual_update`` is false.
        """
        if not constraint_state.contributes_to_dual_update:
            return None

        prepared = self._prepare_kwargs_for_lagrangian_contribution(
            constraint_state=constraint_state,
            multiplier=multiplier,
            penalty_coefficient=None,
            primal_or_dual="dual",
        )
        # The last entry is the penalty coefficient value, always `None` for this formulation.
        violation, multiplier_value, _unused_penalty = prepared

        weighted_violation = formulation_utils.compute_dual_weighted_violation(
            multiplier_value=multiplier_value, violation=violation
        )
        return ContributionStore(
            lagrangian_contribution=weighted_violation,
            multiplier_value=multiplier_value,
            penalty_coefficient_value=None,
        )
[docs]
class QuadraticPenalty(Formulation):
    r"""The Quadratic Penalty formulation implements the following primal Lagrangian:

    .. math::
        \Lag_{\text{primal}}(\vx) = f(\vx) + \frac{1}{2} \vc_{\vg}^\top \,
        \texttt{relu}(\tilde{\vg}(\vx))^2 + \frac{1}{2} \vc_{\vh}^\top \,
        \tilde{\vh}(\vx)^2.

    It does not implement a dual Lagrangian since it does not consider dual variables.

    This formulation requires a penalty coefficient and does not accept a multiplier.
    """

    expects_multiplier = False
    expects_penalty_coefficient = True

    def compute_contribution_to_primal_lagrangian(
        self, constraint_state: ConstraintState, penalty_coefficient: PenaltyCoefficient
    ) -> Optional[ContributionStore]:
        """Computes this constraint's contribution to the primal Lagrangian.

        Args:
            constraint_state: State holding the constraint violations.
            penalty_coefficient: The penalty coefficient module associated with the
                constraint.

        Returns:
            A :py:class:`ContributionStore` with the contribution produced by
            ``formulation_utils.compute_quadratic_penalty`` and the evaluated penalty
            coefficient, or ``None`` if
            ``constraint_state.contributes_to_primal_update`` is false.
        """
        if not constraint_state.contributes_to_primal_update:
            return None

        prepared = self._prepare_kwargs_for_lagrangian_contribution(
            constraint_state=constraint_state,
            multiplier=None,
            penalty_coefficient=penalty_coefficient,
            primal_or_dual="primal",
        )
        # The middle entry is the multiplier value, always `None` for this formulation.
        violation, _unused_multiplier, penalty_coefficient_value = prepared

        penalty_term = formulation_utils.compute_quadratic_penalty(
            penalty_coefficient_value=penalty_coefficient_value,
            violation=violation,
            constraint_type=self.constraint_type,
        )
        return ContributionStore(
            lagrangian_contribution=penalty_term,
            multiplier_value=None,
            penalty_coefficient_value=penalty_coefficient_value,
        )

    def compute_contribution_to_dual_lagrangian(  # noqa: PLR6301
        self,
        constraint_state: ConstraintState,  # noqa: ARG002
        penalty_coefficient: PenaltyCoefficient,  # noqa: ARG002
    ) -> None:
        """The Quadratic Penalty formulation does not involve dual variables and
        therefore does not implement a dual Lagrangian (returns ``None``).

        Both arguments are accepted only to match the primal method's signature
        and are ignored.
        """
        return None
[docs]
class AugmentedLagrangian(Formulation):
    r"""The Augmented Lagrangian formulation implements the following primal Lagrangian:

    .. math::
        \Lag_{\text{primal}}(\vx, \vlambda, \vmu) = f(\vx) + \vlambda^{\top}
        \tilde{\vg}(\vx) + \vmu^{\top} \tilde{\vh}(\vx) + \frac{1}{2} \vc_{\vg}^\top \,
        \texttt{relu}(\tilde{\vg}(\vx))^2 + \frac{1}{2} \vc_{\vh}^\top \,
        \tilde{\vh}(\vx)^2.

    And the following dual Lagrangian:

    .. math::
        \Lag_{\text{dual}}(\vx, \vlambda, \vmu) = \vlambda^{\top} \vg(\vx) + \vmu^{\top} \vh(\vx).

    This formulation requires both a multiplier and a penalty coefficient.
    """

    expects_multiplier = True
    expects_penalty_coefficient = True

    def compute_contribution_to_primal_lagrangian(
        self, constraint_state: ConstraintState, multiplier: Multiplier, penalty_coefficient: PenaltyCoefficient
    ) -> Optional[ContributionStore]:
        """Computes this constraint's contribution to the primal Lagrangian.

        Args:
            constraint_state: State holding the constraint violations.
            multiplier: The multiplier module associated with the constraint.
            penalty_coefficient: The penalty coefficient module associated with the
                constraint.

        Returns:
            A :py:class:`ContributionStore` with the contribution produced by
            ``formulation_utils.compute_primal_quadratic_augmented_contribution`` and
            both evaluated constraint factors, or ``None`` if
            ``constraint_state.contributes_to_primal_update`` is false.
        """
        if not constraint_state.contributes_to_primal_update:
            return None

        prepared = self._prepare_kwargs_for_lagrangian_contribution(
            constraint_state=constraint_state,
            multiplier=multiplier,
            penalty_coefficient=penalty_coefficient,
            primal_or_dual="primal",
        )
        violation, multiplier_value, penalty_coefficient_value = prepared

        augmented_term = formulation_utils.compute_primal_quadratic_augmented_contribution(
            multiplier_value=multiplier_value,
            penalty_coefficient_value=penalty_coefficient_value,
            violation=violation,
            constraint_type=self.constraint_type,
        )
        return ContributionStore(
            lagrangian_contribution=augmented_term,
            multiplier_value=multiplier_value,
            penalty_coefficient_value=penalty_coefficient_value,
        )

    def compute_contribution_to_dual_lagrangian(
        self, constraint_state: ConstraintState, multiplier: Multiplier, penalty_coefficient: PenaltyCoefficient
    ) -> Optional[ContributionStore]:
        """Computes this constraint's contribution to the dual Lagrangian.

        The penalty coefficient is evaluated and included in the returned store, but
        it does not enter the dual contribution itself.

        Args:
            constraint_state: State holding the constraint violations.
            multiplier: The multiplier module associated with the constraint.
            penalty_coefficient: The penalty coefficient module associated with the
                constraint.

        Returns:
            A :py:class:`ContributionStore` with the contribution produced by
            ``formulation_utils.compute_dual_weighted_violation`` and both evaluated
            constraint factors, or ``None`` if
            ``constraint_state.contributes_to_dual_update`` is false.
        """
        if not constraint_state.contributes_to_dual_update:
            return None

        prepared = self._prepare_kwargs_for_lagrangian_contribution(
            constraint_state=constraint_state,
            multiplier=multiplier,
            penalty_coefficient=penalty_coefficient,
            primal_or_dual="dual",
        )
        violation, multiplier_value, penalty_coefficient_value = prepared

        # The penalty coefficient is deliberately not passed: the dual Lagrangian is
        # just the sum of the violation times the multiplier.
        weighted_violation = formulation_utils.compute_dual_weighted_violation(
            multiplier_value=multiplier_value, violation=violation
        )
        return ContributionStore(
            lagrangian_contribution=weighted_violation,
            multiplier_value=multiplier_value,
            penalty_coefficient_value=penalty_coefficient_value,
        )