Source code for cooper.formulation.augmented_lagrangian

"""Lagrangian formulation"""

from typing import Callable, Optional, Tuple, no_type_check

import torch

from cooper.problem import CMPState, ConstrainedMinimizationProblem

from .lagrangian import LagrangianFormulation


class AugmentedLagrangianFormulation(LagrangianFormulation):
    """
    Provides utilities for computing the Augmented Lagrangian associated with
    a ``ConstrainedMinimizationProblem`` and for populating the gradients for
    the primal and dual parameters accordingly.

    Args:
        cmp: ``ConstrainedMinimizationProblem`` we aim to solve and which
            gives rise to the Lagrangian.
        ineq_init: Initialization values for the inequality multipliers.
        eq_init: Initialization values for the equality multipliers.
    """

    def __init__(
        self,
        cmp: Optional[ConstrainedMinimizationProblem] = None,
        ineq_init: Optional[torch.Tensor] = None,
        eq_init: Optional[torch.Tensor] = None,
    ):
        """Constructs a new ``AugmentedLagrangianFormulation``."""
        super().__init__(cmp=cmp, ineq_init=ineq_init, eq_init=eq_init)
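
    # The two methods below assemble the Augmented Lagrangian described in the
    # Toussaint slides referenced in `compute_lagrangian`:
    #
    #   L_c(x, lambda, mu) = f(x) + sum_i lambda_i * g_i(x) + sum_j mu_j * h_j(x)
    #                        + (c / 2) * (sum_i I_i * g_i(x)^2 + sum_j h_j(x)^2)
    #
    # where g_i <= 0 are the inequality defects, h_j = 0 the equality defects,
    # I_i = 1 if (g_i(x) >= 0 or lambda_i > 0) else 0, and the penalty
    # coefficient c is taken to be the current dual learning rate.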

    def weighted_violation(
        self, cmp_state: CMPState, constraint_type: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the dot product between the current multipliers and the
        constraint violations of type ``constraint_type``, along with the
        (filtered) sum of squared violations used for the penalty term. If
        proxy-constraints are provided in the :py:class:`.CMPState`, the
        differentiable proxy-constraints are used for computing these terms,
        while the dot product with the non-proxy (usually non-differentiable)
        constraints is stored under ``self.accumulated_violation_dot_prod``
        and is later used by :py:meth:`.backward` to update the multipliers.

        If the ``CMPState`` contains proxy _inequality_ constraints, the
        filtering on whether a constraint is active for the calculation of
        the Augmented Lagrangian penalty is done based on the value of the
        non-proxy constraints.

        Args:
            cmp_state: Current ``CMPState``.
            constraint_type: Type of constraint to be used, e.g. "eq" or
                "ineq".
        """

        defect = getattr(cmp_state, constraint_type + "_defect")
        has_defect = defect is not None

        proxy_defect = getattr(cmp_state, "proxy_" + constraint_type + "_defect")
        has_proxy_defect = proxy_defect is not None

        if not has_proxy_defect:
            # If not given proxy constraints, then the regular defects are
            # used for computing gradients and evaluating the multipliers
            proxy_defect = defect

        if not has_defect:
            # We should always have at least the regular defects; if not, the
            # problem instance does not have `constraint_type` constraints
            proxy_violation = torch.tensor([0.0], device=cmp_state.loss.device)
            sq_proxy_violation = torch.tensor([0.0], device=cmp_state.loss.device)
        else:
            multipliers = getattr(self, constraint_type + "_multipliers")()

            if constraint_type == "ineq":
                # Compute filter based on the non-proxy constraint defects
                const_filter = torch.logical_or(defect >= 0, multipliers > 0).detach()
            else:
                # Equality constraints do not need to be filtered
                const_filter = 1.0

            # We compute (primal) gradients of this squared-penalty term
            sq_proxy_violation = torch.sum(const_filter * (proxy_defect) ** 2)

            # We compute (primal) gradients of this object
            proxy_violation = torch.sum(multipliers.detach() * proxy_defect)

            # This is the violation of the "actual" constraint. We use this
            # to update the value of the multipliers by lazily filling the
            # multiplier gradients in `backward`
            # TODO (JGP): Verify that call to backward is general enough for
            # Lagrange Multiplier models
            violation_for_update = torch.sum(multipliers * defect.detach())
            self.update_accumulated_violation(update=violation_for_update)

        return proxy_violation, sq_proxy_violation
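
    # Illustration of the active-set filter used above (derived from the code,
    # with made-up numbers): for defects g = [-0.5, 0.2] and multipliers
    # lambda = [0.3, 0.0], the filter is [True, True] (the first entry because
    # lambda_1 > 0, the second because g_2 >= 0), so both entries contribute
    # to the squared-penalty term. For lambda = [0.0, 0.0], only the violated
    # constraint g_2 would contribute.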

    @no_type_check
    def compute_lagrangian(
        self,
        aug_lag_coeff_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
        closure: Optional[Callable[..., CMPState]] = None,
        *closure_args,
        pre_computed_state: Optional[CMPState] = None,
        write_state: Optional[bool] = True,
        **closure_kwargs
    ) -> torch.Tensor:
        """
        Computes the Augmented Lagrangian based on a new evaluation of the
        :py:class:`~cooper.problem.CMPState` via the ``closure`` function.

        If no explicit proxy-constraints are provided, the given
        inequality/equality constraints are used both to compute the Augmented
        Lagrangian and to populate the primal and dual gradients. Note that
        gradients are _not_ populated by this function, but rather by
        :py:meth:`.backward`.

        If proxy constraints are provided in the CMPState, the differentiable
        proxy constraints are used for computing the value of the Augmented
        Lagrangian, while the dot product with the non-proxy (potentially
        non-differentiable) constraints is accumulated and used in the
        backward computation triggered by :py:meth:`.backward` to update the
        multipliers.

        Args:
            aug_lag_coeff_scheduler: Scheduler over the dual learning rate,
                whose current value is used as the Augmented Lagrangian
                penalty coefficient.
            closure: Callable returning a :py:class:`cooper.problem.CMPState`.
            pre_computed_state: Pre-computed CMP state to avoid wasteful
                computation when only dual gradients are required.
            write_state: If ``True``, the ``state`` of the formulation's
                :py:class:`cooper.problem.ConstrainedMinimizationProblem`
                attribute is replaced by that returned by the ``closure``
                argument. This flag can be used (when set to ``False``) to
                evaluate the Augmented Lagrangian, e.g. for logging validation
                metrics, without overwriting the information stored in the
                formulation's
                :py:class:`cooper.problem.ConstrainedMinimizationProblem`.
        """

        assert (
            closure is not None or pre_computed_state is not None
        ), "At least one of closure or pre_computed_state must be provided"

        if pre_computed_state is not None:
            cmp_state = pre_computed_state
        else:
            cmp_state = closure(*closure_args, **closure_kwargs)

        if write_state and self.cmp is not None:
            self.write_cmp_state(cmp_state)

        # Extract values from ProblemState object
        loss = cmp_state.loss

        if not self.is_state_created:
            # If not done before, instantiate and initialize dual variables
            self.create_state(cmp_state)

        # Purge previously accumulated constraint violations
        self.update_accumulated_violation(update=None)

        # Compute the contribution of the constraint violations, weighted by
        # the current multiplier values. If given proxy constraints, these are
        # used to compute the terms added to the Lagrangian, and the
        # multiplier updates are based on the non-proxy violations. If not
        # given proxy constraints, then gradients and multiplier updates are
        # both based on the "regular" constraints.
        ineq_viol, sq_ineq_viol = self.weighted_violation(cmp_state, "ineq")
        eq_viol, sq_eq_viol = self.weighted_violation(cmp_state, "eq")

        # Lagrangian = loss + \sum_i multiplier_i * defect_i
        lagrangian = loss + ineq_viol + eq_viol

        # Gather the learning rates of all "parameter groups" of the dual
        # variables and check that they are all equal.
        dual_lrs = aug_lag_coeff_scheduler.get_last_lr()
        is_all_dual_lr_equal = all(x == dual_lrs[0] for x in dual_lrs)
        assert is_all_dual_lr_equal, "All the dual LRs must be the same."

        # Use the dual learning rate as the Augmented Lagrangian coefficient
        # so that the gradient-based update coincides with the update scheme
        # of the Augmented Lagrangian method.
        augmented_lagrangian_coefficient = dual_lrs[0]
        if augmented_lagrangian_coefficient > 0:
            # If using an Augmented Lagrangian, add the (halved) squared sum
            # of the constraint violations, following the formulation in Marc
            # Toussaint's slides (pp. 17-20):
            # https://ipvs.informatik.uni-stuttgart.de/mlr/marc/teaching/13-Optimization/03-constrainedOpt.pdf
            lagrangian += (
                0.5 * augmented_lagrangian_coefficient * (sq_ineq_viol + sq_eq_viol)
            )

        return lagrangian
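

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes the base Lagrangian formulation exposes `create_state`,
# `dual_parameters`, and `backward` as referenced above, and that multipliers
# default to a zero initialization. The toy tensors, learning rate, and
# scheduler choice below are made up for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    x = torch.tensor([0.0, 0.0], requires_grad=True)

    # Toy state: loss = ||x||^2 with one inequality defect 1 - x_0 - x_1 <= 0,
    # which is violated (defect = 1) at the starting point.
    cmp_state = CMPState(loss=torch.sum(x ** 2), ineq_defect=1.0 - x[0] - x[1])

    formulation = AugmentedLagrangianFormulation()
    formulation.create_state(cmp_state)  # instantiate inequality multipliers

    # The dual learning rate, read off the scheduler, doubles as the Augmented
    # Lagrangian penalty coefficient c inside `compute_lagrangian`.
    dual_optimizer = torch.optim.SGD(formulation.dual_parameters, lr=0.5)
    aug_lag_coeff_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        dual_optimizer, gamma=1.0
    )

    # With zero multipliers and c = 0.5, this evaluates to
    # loss + 0.5 * c * defect^2 = 0 + 0.25 * 1 = 0.25.
    lagrangian = formulation.compute_lagrangian(
        aug_lag_coeff_scheduler, pre_computed_state=cmp_state, write_state=False
    )
    formulation.backward(lagrangian)  # populates primal and dual gradients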