# Source code for cooper.multipliers.multipliers

# Copyright (C) 2025 The Cooper Developers.
# Licensed under the MIT License.

"""Classes for modeling dual variables (e.g. Lagrange multipliers)."""

import abc
from typing import Any, Optional

import torch

from cooper.utils import ConstraintType


class Multiplier(torch.nn.Module, abc.ABC):
    """Abstract base class for modules that represent dual variables.

    Subclasses state whether their ``forward`` consumes constraint features through
    ``expects_constraint_features``, and are bound to a constraint type by calling
    :py:meth:`set_constraint_type`.
    """

    # Whether ``forward`` must be called with per-constraint features as input.
    expects_constraint_features: bool
    # Kind of constraint this multiplier belongs to; assigned by ``set_constraint_type``.
    constraint_type: ConstraintType

    @abc.abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Return the current value of the multiplier."""

    @abc.abstractmethod
    def post_step_(self) -> None:
        """Hook executed after every step of the dual optimizer.

        Subclasses use it to post-process the multiplier module or its parameters
        (for instance, projecting them back onto a feasible set).
        """

    def sanity_check(self) -> None:
        """Validate the multiplier against its constraint type.

        Invoked by :py:meth:`set_constraint_type` right after the constraint type is
        stored, so that subclasses can enforce consistency (e.g. multipliers of
        inequality constraints must be non-negative). The base implementation accepts
        everything.
        """

    def set_constraint_type(self, constraint_type: ConstraintType) -> None:
        """Store ``constraint_type`` on the multiplier, then run :py:meth:`sanity_check`."""
        self.constraint_type = constraint_type
        self.sanity_check()
class ExplicitMultiplier(Multiplier):
    """An ExplicitMultiplier holds a :py:class:`torch.nn.parameter.Parameter` (`weight`)
    which explicitly contains the value of the Lagrange multipliers associated with a
    :py:class:`~cooper.constraints.Constraint` in a
    :py:class:`~cooper.cmp.ConstrainedMinimizationProblem`.

    Args:
        num_constraints: Number of constraints associated with the multiplier.
        init: Tensor used to initialize the multiplier values. If both ``init`` and
            ``num_constraints`` are provided, ``init`` must have shape
            ``(num_constraints,)``. The values of ``init`` are copied, so updating the
            multiplier never modifies the caller's tensor.
        device: Device for the multiplier. If ``None``, the device is inferred from the
            ``init`` tensor or the default device.
        dtype: Data type for the multiplier. Default is ``torch.float32``.
    """

    def __init__(
        self,
        num_constraints: Optional[int] = None,
        init: Optional[torch.Tensor] = None,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        self.weight = self.initialize_weight(
            num_constraints=num_constraints, init=init, device=device, dtype=dtype
        )

    @staticmethod
    def initialize_weight(
        num_constraints: Optional[int],
        init: Optional[torch.Tensor],
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
    ) -> torch.nn.Parameter:
        """Initialize the weight of the multiplier.

        If both ``init`` and ``num_constraints`` are provided (and the shapes are
        consistent), ``init`` takes precedence. If only ``num_constraints`` is given,
        the weight is initialized to :py:func:`torch.zeros` of shape
        ``(num_constraints,)``.

        When ``init`` is given, its values are detached and copied into fresh storage
        with the requested ``device`` and ``dtype``. The returned parameter never
        shares memory with ``init``.

        Raises:
            ValueError: If both ``num_constraints`` and ``init`` are ``None``.
            ValueError: If the provided ``init`` is not a 1D tensor.
            ValueError: If both ``num_constraints`` and ``init`` are provided but their
                shapes are inconsistent.
        """
        if num_constraints is None and init is None:
            raise ValueError("At least one of `num_constraints` and `init` must be provided.")

        if init is None:
            return torch.nn.Parameter(torch.zeros(num_constraints, device=device, dtype=dtype))

        # Check the rank before reading ``init.shape[0]``. A 0-dim tensor has no first
        # dimension, so the shape comparison would otherwise raise an IndexError
        # instead of the documented ValueError.
        if init.dim() != 1:
            raise ValueError("`init` must be a 1D tensor of shape `(num_constraints,)`.")
        if num_constraints is not None and num_constraints != init.shape[0]:
            raise ValueError(
                f"Inconsistent `init` shape {init.shape} and `num_constraints={num_constraints}`"
            )

        # Without ``copy=True``, ``.to`` returns ``init`` itself when device and dtype
        # already match. The Parameter would then share storage with the caller's
        # tensor, and every dual optimizer step would silently overwrite it.
        return torch.nn.Parameter(init.detach().to(device=device, dtype=dtype, copy=True))

    @property
    def device(self) -> torch.device:
        """Device on which the multiplier values (``weight``) are stored."""
        return self.weight.device

    def sanity_check(self) -> None:
        """Ensures multipliers for inequality constraints are non-negative.

        Raises:
            ValueError: If the multiplier is associated with an inequality constraint
                and any of its entries is negative.
        """
        if self.constraint_type == ConstraintType.INEQUALITY and torch.any(self.weight.data < 0):
            raise ValueError("For inequality constraint, all entries in multiplier must be non-negative.")

    @torch.no_grad()
    def post_step_(self) -> None:
        """Projects (in-place) multipliers associated with inequality constraints so that
        they remain non-negative.

        This function is called after each dual optimizer step.
        """
        if self.constraint_type == ConstraintType.INEQUALITY:
            # Project onto the non-negative orthant in place. This reuses the existing
            # storage instead of allocating a new tensor on every step.
            self.weight.data.relu_()

    def __repr__(self) -> str:
        return f"{type(self).__name__}(num_constraints={self.weight.shape[0]})"
class DenseMultiplier(ExplicitMultiplier):
    r"""Sub-class of :py:class:`~cooper.multipliers.ExplicitMultiplier` for constraints
    that are all evaluated at every optimization step.
    """

    # Every entry is returned on each call, so no constraint features are consumed.
    expects_constraint_features = False

    def forward(self) -> torch.Tensor:
        """Return a copy of all current multiplier values.

        The result is a differentiable clone of ``weight``. Gradients flow back to the
        parameter, but the returned tensor does not share storage with it.
        """
        values = self.weight.clone()
        return values
class IndexedMultiplier(ExplicitMultiplier):
    r""":py:class:`~cooper.multipliers.ExplicitMultiplier` for indexed constraints which
    are evaluated only for a subset of constraints on every optimization step.
    """

    # ``forward`` takes the indices of the constraints being evaluated.
    expects_constraint_features = True

    def __init__(
        self,
        num_constraints: Optional[int] = None,
        init: Optional[torch.Tensor] = None,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__(num_constraints, init, device, dtype)
        # ``torch.nn.functional.embedding`` needs a 2-D weight, so a 1-D vector of
        # multipliers is stored as a single column.
        if self.weight.dim() == 1:
            self.weight.data = self.weight.data.unsqueeze(-1)

    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        """Look up the multiplier values for the given constraint indices.

        Args:
            indices: Indices of the multipliers to return. The shape of ``indices``
                must be ``(num_indices,)``.

        Raises:
            ValueError: If ``indices`` dtype is not ``torch.long``.
        """
        # Only integer indices are accepted. Boolean tensors would be read by
        # ``embedding`` as indices rather than as masks.
        if indices.dtype != torch.long:
            raise ValueError("Indices must be of type torch.long.")

        # TODO(gallego-posada): Document sparse gradients are expected for stateful
        # optimizers (having buffers)
        looked_up = torch.nn.functional.embedding(indices, self.weight, sparse=True)
        # ``embedding`` returns one single-column row per index; collapse them to 1-D.
        return looked_up.flatten()
class ImplicitMultiplier(Multiplier):
    """An implicit multiplier is a :py:class:`torch.nn.Module` that computes the value
    of a Lagrange multiplier associated with a :py:class:`~cooper.constraints.Constraint`
    based on the "features" for each constraint.

    The multiplier is *implicitly* represented by its parameters.
    """

    @abc.abstractmethod
    def forward(self) -> torch.Tensor:
        """Compute the multiplier values; concrete subclasses must implement this."""

    @abc.abstractmethod
    def post_step_(self) -> None:
        """Hook run after each step of the dual optimizer.

        Subclasses use it to post-process the implicit multiplier module or its
        parameters.
        """