Source code for cooper.penalty_coefficients.penalty_coefficient_updaters

# Copyright (C) 2025 The Cooper Developers.
# Licensed under the MIT License.

from __future__ import annotations

import abc
from typing import TYPE_CHECKING

import torch

from cooper.penalty_coefficients.penalty_coefficients import DensePenaltyCoefficient, IndexedPenaltyCoefficient
from cooper.utils import ConstraintType

if TYPE_CHECKING:
    from cooper.constraints import Constraint, ConstraintState


[docs] class PenaltyCoefficientUpdater(abc.ABC): """Abstract class for updating the penalty coefficient of a constraint."""
[docs] def step(self, observed_constraints: dict[Constraint, ConstraintState]) -> None: r"""Trigger updates on the penalty coefficients for each of the ``observed_constraints``. For each constraint in ``observed_constraints``, this method determines whether its penalty coefficient should be updated. The decision depends on properties like whether the constraint contributes to primal/dual updates and the availability of strict violation measurements. .. admonition:: Primal vs Dual Contributions :class: note - For formulations expecting multipliers (e.g., AugmentedLagrangian), updates occur if: - The constraint contributes to the dual update, **OR** - It contributes to the primal update **and** has a strict violation measurement. - For primal-only formulations (e.g., QuadraticPenalty), updates occur only if the constraint contributes to the primal update. Args: observed_constraints: Dictionary with :py:class:`~Constraint` instances as keys and :py:class:`~ConstraintState` instances as values (containing tensors :math:`\vg(\vx_t)` and :math:`\vh(\vx_t)`). """ for constraint, constraint_state in observed_constraints.items(): if constraint.penalty_coefficient is None: # Skip constraints without penalty coefficients continue if constraint.formulation.expects_multiplier: # If the user provides a "surrogate" as the `violation` in the constraint state, and does not # provide a `strict_violation`, and the constraint is marked as contributing to the dual update, our # convention is that said violation is a dual-valid measurement, and thus can be relied on for # updating the penalty coefficient. contributes_to_dual = constraint_state.contributes_to_dual_update # If we only have the `violation` but not the `strict_violation` we cannot be certain that the given # `violation` is not a surrogate. Therefore, we cannot rely on it to update the penalty coefficients. # On the other hand, if `strict_violation` is given, it can be used to update the penalty coefficients. contributes_to_primal = constraint_state.contributes_to_primal_update has_strict_violation = constraint_state.strict_violation is not None should_update = contributes_to_dual or (contributes_to_primal and has_strict_violation) else: # If we have a primal only formulation (like QuadraticPenalty), # `constraint_state.contributes_to_dual_update` must be `False`. # Therefore, we only update the penalty coefficient if # `constraint_state.contributes_to_primal_update=True`. should_update = constraint_state.contributes_to_primal_update if should_update: self.update_penalty_coefficient_(constraint, constraint_state)
@abc.abstractmethod def update_penalty_coefficient_(self, constraint: Constraint, constraint_state: ConstraintState) -> None: """Update the penalty coefficient of a constraint. Args: constraint: The constraint for which the penalty coefficient is updated. constraint_state: The constraint state of the constraint. """
class FeasibilityDrivenPenaltyCoefficientUpdater(PenaltyCoefficientUpdater, abc.ABC): def __init__(self, violation_tolerance: float, has_restart: bool) -> None: if violation_tolerance < 0.0: raise ValueError("Violation tolerance must be non-negative.") self.violation_tolerance = violation_tolerance self.has_restart = has_restart def update_penalty_coefficient_(self, constraint: Constraint, constraint_state: ConstraintState) -> None: # Extract violations and features _, strict_violation = constraint_state.extract_violations() _, strict_constraint_features = constraint_state.extract_constraint_features() penalty_coefficient = constraint.penalty_coefficient # Get current penalty values if isinstance(penalty_coefficient, DensePenaltyCoefficient): observed_penalty_values = penalty_coefficient() elif isinstance(penalty_coefficient, IndexedPenaltyCoefficient): observed_penalty_values = penalty_coefficient(strict_constraint_features) else: raise TypeError(f"Unsupported penalty coefficient type: {type(penalty_coefficient)}") if constraint.constraint_type == ConstraintType.INEQUALITY: # For inequality constraints, we only consider the non-negative part of the violation. strict_violation = strict_violation.relu() is_scalar = observed_penalty_values.dim() == 0 violation_measure = strict_violation.norm() if is_scalar else strict_violation.abs() # Check where the violation exceeds the allowed tolerance violation_exceeds_tolerance = violation_measure > self.violation_tolerance # Compute base new value new_value = self._compute_updated_penalties(observed_penalty_values, violation_exceeds_tolerance) # Restart the penalty coefficient to its initial value if inequality constraint is satisfied. if self.has_restart and constraint.constraint_type == ConstraintType.INEQUALITY: # The strict violation has relu applied to it, so we can check feasibility by comparing to 0. is_feasible = torch.eq(violation_measure, 0) new_value = torch.where(is_feasible, penalty_coefficient.init, new_value) if isinstance(penalty_coefficient, IndexedPenaltyCoefficient) and new_value.dim() > 0: penalty_coefficient.value[strict_constraint_features] = new_value.detach() else: penalty_coefficient.value = new_value.detach() @abc.abstractmethod def _compute_updated_penalties( self, current_penalty_value: torch.Tensor, should_increase_penalty: torch.Tensor ) -> torch.Tensor: """Compute updated penalty values based on violation status."""
[docs] class MultiplicativePenaltyCoefficientUpdater(FeasibilityDrivenPenaltyCoefficientUpdater): r"""Multiplicative updater for :py:class:`~cooper.penalty_coefficients.PenaltyCoefficient`\s. The penalty coefficient is updated by multiplying it by ``growth_factor`` when the constraint violation is larger than ``violation_tolerance``. Based on Algorithm 17.4 in :cite:t:`nocedal2006NumericalOptimization`. Args: growth_factor: The factor by which the penalty coefficient is multiplied when the constraint is violated beyond ``violation_tolerance``. violation_tolerance: The tolerance for the constraint violation. If the violation is smaller than this tolerance, the penalty coefficient is not updated. The comparison is done at the constraint-level (i.e., each entry of the violation tensor). For equality constraints, the absolute violation is compared to the tolerance. All constraint types use the strict violation (when available) for the comparison. has_restart: Whether to restart the penalty coefficient to its initial value when the inequality constraint is satisfied. This is only applicable to inequality constraints. Raises: ValueError: If the violation tolerance is negative. """ def __init__( self, growth_factor: float = 1.01, violation_tolerance: float = 1e-4, has_restart: bool = True ) -> None: super().__init__(violation_tolerance, has_restart) self.growth_factor = growth_factor def _compute_updated_penalties( self, current_penalty_value: torch.Tensor, should_increase_penalty: torch.Tensor ) -> torch.Tensor: return torch.where(should_increase_penalty, current_penalty_value * self.growth_factor, current_penalty_value)
[docs] class AdditivePenaltyCoefficientUpdater(FeasibilityDrivenPenaltyCoefficientUpdater): r"""Additive updater for :py:class:`~cooper.penalty_coefficients.PenaltyCoefficient`\s. The penalty coefficient is updated by adding ``increment`` when the constraint violation is larger than ``violation_tolerance``. Args: increment: The constant value by which the penalty coefficient is added when the constraint is violated beyond ``violation_tolerance``. violation_tolerance: The tolerance for the constraint violation. If the violation is smaller than this tolerance, the penalty coefficient is not updated. The comparison is done at the constraint-level (i.e., each entry of the violation tensor). For equality constraints, the absolute violation is compared to the tolerance. All constraint types use the strict violation (when available) for the comparison. has_restart: Whether to restart the penalty coefficient to its initial value when the inequality constraint is satisfied. This is only applicable to inequality constraints. Raises: ValueError: If the violation tolerance is negative. """ def __init__(self, increment: float = 1.0, violation_tolerance: float = 1e-4, has_restart: bool = True) -> None: super().__init__(violation_tolerance, has_restart) self.increment = increment def _compute_updated_penalties( self, current_penalty_value: torch.Tensor, should_increase_penalty: torch.Tensor ) -> torch.Tensor: return torch.where(should_increase_penalty, current_penalty_value + self.increment, current_penalty_value)