Source code for cooper.lagrangian_formulation
"""Lagrangian formulation"""
import abc
from typing import Callable, List, Optional, Tuple, no_type_check
import torch
from .multipliers import DenseMultiplier
from .problem import CMPState, ConstrainedMinimizationProblem, Formulation
class BaseLagrangianFormulation(Formulation, metaclass=abc.ABCMeta):
"""
Base class for Lagrangian Formulations.
Attributes:
cmp: :py:class:`~cooper.problem.ConstrainedMinimizationProblem` we aim
to solve and which gives rise to the Lagrangian.
ineq_multipliers: Trainable
:py:class:`cooper.multipliers.DenseMultiplier`\\s associated with
the inequality constraints.
eq_multipliers: Trainable
:py:class:`cooper.multipliers.DenseMultiplier`\\s associated with
the equality constraints.
"""
def __init__(
self,
cmp: ConstrainedMinimizationProblem,
ineq_init: Optional[torch.Tensor] = None,
eq_init: Optional[torch.Tensor] = None,
aug_lag_coefficient: float = 0.0,
):
"""Construct new `LagrangianFormulation`"""
self.cmp = cmp
self.ineq_multipliers = None
self.eq_multipliers = None
# Store user-provided initializations for dual variables
self.ineq_init = ineq_init
self.eq_init = eq_init
self.state_update: List[torch.Tensor] = []
if aug_lag_coefficient < 0:
raise ValueError("Augmented Lagrangian coefficient must be non-negative.")
self.aug_lag_coefficient = aug_lag_coefficient
@property
def dual_parameters(self) -> List[torch.Tensor]:
"""Returns a list gathering all dual parameters"""
return [_ for _ in self.state() if _ is not None]
    def state(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Collects all dual variables and returns a tuple containing their
:py:class:`torch.Tensor` values. Note that the *values* are a different
type from the :py:class:`cooper.multipliers.DenseMultiplier` objects.
"""
if self.ineq_multipliers is None:
ineq_state = None
else:
ineq_state = self.ineq_multipliers()
if self.eq_multipliers is None:
eq_state = None
else:
eq_state = self.eq_multipliers()
return ineq_state, eq_state # type: ignore
    def create_state(self, cmp_state):
        """Initialize the dual variables, as
        :py:class:`cooper.multipliers.DenseMultiplier`\\s, given the equality
        and inequality defects stored in ``cmp_state``.

        Args:
            cmp_state: current ``CMPState`` whose ``eq_defect`` and
                ``ineq_defect`` attributes determine the shape, device, and
                dtype of the multipliers.
        """
# Ensure that dual variables are not re-initialized
for constraint_type in ["eq", "ineq"]:
mult_name = constraint_type + "_multipliers"
defect = getattr(cmp_state, constraint_type + "_defect")
proxy_defect = getattr(cmp_state, "proxy_" + constraint_type + "_defect")
has_defect = defect is not None
has_proxy_defect = proxy_defect is not None
if has_defect or has_proxy_defect:
# Ensure dual variables have not been initialized previously
assert getattr(self, constraint_type + "_multipliers") is None
# If given proxy and non-proxy defects, sanity-check shapes
if has_defect and has_proxy_defect:
assert defect.shape == proxy_defect.shape
# Choose a tensor for getting device and dtype information
defect_for_init = defect if has_defect else proxy_defect
init_tensor = getattr(self, constraint_type + "_init")
if init_tensor is None:
# If not provided custom initialization, Lagrange
# multipliers are initialized at 0
# This already preserves dtype and device of defect
casted_init = torch.zeros_like(defect_for_init)
else:
casted_init = torch.tensor(
init_tensor,
device=defect_for_init.device,
dtype=defect_for_init.dtype,
)
assert defect_for_init.shape == casted_init.shape
# Enforce positivity if dealing with inequality
is_positive = constraint_type == "ineq"
multiplier = DenseMultiplier(casted_init, positive=is_positive)
setattr(self, mult_name, multiplier)
@property
def is_state_created(self):
"""
Returns ``True`` if any Lagrange multipliers have been initialized.
"""
return self.ineq_multipliers is not None or self.eq_multipliers is not None
    def purge_state_update(self):
        """Clears the list of stored constraint violations used for updating
        the multipliers."""
        self.state_update = []
    def weighted_violation(
self, cmp_state: CMPState, constraint_type: str
) -> torch.Tensor:
"""
Computes the dot product between the current multipliers and the
constraint violations of type ``constraint_type``. If proxy-constraints
are provided in the :py:class:`.CMPState`, the non-proxy (usually
non-differentiable) constraints are used for computing the dot product,
while the "proxy-constraint" dot products are stored under
``self.state_update``.
Args:
cmp_state: current ``CMPState``
constraint_type: type of constrained to be used
"""
defect = getattr(cmp_state, constraint_type + "_defect")
has_defect = defect is not None
proxy_defect = getattr(cmp_state, "proxy_" + constraint_type + "_defect")
has_proxy_defect = proxy_defect is not None
if not has_proxy_defect:
# If not given proxy constraints, then the regular defects are
# used for computing gradients and evaluating the multipliers
proxy_defect = defect
if not has_defect:
# We should always have at least the regular defects, if not, then
# the problem instance does not have `constraint_type` constraints
proxy_violation = torch.tensor([0.0], device=cmp_state.loss.device)
else:
multipliers = getattr(self, constraint_type + "_multipliers")()
# We compute (primal) gradients of this object
proxy_violation = torch.sum(multipliers.detach() * proxy_defect)
# This is the violation of the "actual" constraint. We use this
# to update the value of the multipliers by lazily filling the
# multiplier gradients in `populate_gradients`
violation_for_update = torch.sum(multipliers * defect.detach())
self.state_update.append(violation_for_update)
return proxy_violation
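
# NOTE: The detach pattern used in `weighted_violation` above can be
# illustrated in isolation. The sketch below is illustrative only (not part
# of the library API): detaching the multiplier routes primal gradients
# through the (proxy) defect, while detaching the defect routes dual
# gradients through the multiplier.
def _weighted_violation_sketch():
    lam = torch.tensor([0.5], requires_grad=True)  # stand-in multiplier value
    x = torch.tensor([2.0], requires_grad=True)  # stand-in primal variable
    defect = x - 1.0  # stand-in (differentiable) constraint defect

    # Gradient flows to `x` only: the multiplier is treated as a constant
    proxy_violation = torch.sum(lam.detach() * defect)
    proxy_violation.backward()
    assert x.grad is not None and lam.grad is None

    # Gradient flows to `lam` only: d/d(lam) [lam * defect] = defect
    violation_for_update = torch.sum(lam * defect.detach())
    violation_for_update.backward()
    assert torch.allclose(lam.grad, defect.detach())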
class LagrangianFormulation(BaseLagrangianFormulation):
"""
Provides utilities for computing the Lagrangian associated with a
``ConstrainedMinimizationProblem`` and for populating the gradients for the
primal and dual parameters.
Args:
cmp: ``ConstrainedMinimizationProblem`` we aim to solve and which gives
rise to the Lagrangian.
ineq_init: Initialization values for the inequality multipliers.
eq_init: Initialization values for the equality multipliers.
aug_lag_coefficient: Coefficient used for the augmented Lagrangian.
"""
    @no_type_check
def composite_objective(
self,
closure: Callable[..., CMPState],
*closure_args,
write_state: bool = True,
**closure_kwargs
) -> torch.Tensor:
"""
        Computes the Lagrangian based on a new evaluation of the
        :py:class:`~cooper.problem.CMPState`.
        If no explicit proxy-constraints are provided, the given
        inequality/equality constraints are used both to compute the
        Lagrangian and to populate the primal and dual gradients.

        In case proxy constraints are provided in the CMPState, the proxy
        constraints (which must be differentiable) are used for computing the
        Lagrangian terms through which the primal gradients flow, while the
        non-proxy constraints (potentially non-differentiable) drive the
        multiplier updates performed in :py:meth:`._populate_gradients`.
Args:
closure: Callable returning a :py:class:`cooper.problem.CMPState`
            write_state: If ``True``, the ``state`` of the formulation's
                :py:class:`cooper.problem.ConstrainedMinimizationProblem`
                attribute is replaced by that returned by the ``closure``
                argument. This flag can be used (when set to ``False``) to
                evaluate the Lagrangian, e.g. for logging validation metrics,
                without overwriting the information stored in the formulation's
                :py:class:`cooper.problem.ConstrainedMinimizationProblem`.
"""
cmp_state = closure(*closure_args, **closure_kwargs)
if write_state:
self.cmp.state = cmp_state
        # Extract values from CMPState object
loss = cmp_state.loss
ineq_defect, eq_defect = cmp_state.ineq_defect, cmp_state.eq_defect
if self.cmp.is_constrained and (not self.is_state_created):
# If not done before, instantiate and initialize dual variables
self.create_state(cmp_state)
# Compute Lagrangian based on current loss and values of multipliers
self.purge_state_update()
if self.cmp.is_constrained:
# Compute contribution of the constraint violations, weighted by the
# current multiplier values
# If given proxy constraints, these are used to compute the terms
# added to the Lagrangian, and the multiplier updates are based on
# the non-proxy violations.
# If not given proxy constraints, then gradients and multiplier
# updates are based on the "regular" constraints.
ineq_viol = self.weighted_violation(cmp_state, "ineq")
eq_viol = self.weighted_violation(cmp_state, "eq")
# Lagrangian = loss + \sum_i multiplier_i * defect_i
lagrangian = loss + ineq_viol + eq_viol
# TODO(JGP): [1] verify that current implementation of proxy
# constraints works properly with augmented lagrangian below.
# If using augmented Lagrangian, add squared sum of constraints
# Following the formulation on Marc Toussaint slides (p 17-20)
# https://ipvs.informatik.uni-stuttgart.de/mlr/marc/teaching/13-Optimization/03-constrainedOpt.pdf
if self.aug_lag_coefficient > 0:
# TODO(JGP): [2] I guess one would like to filter based on
# non-proxy feasibility but penalize based on the proxy defect
                if ineq_defect is not None:
                    assert self.ineq_multipliers is not None
                    # Penalize an inequality constraint only if it is violated
                    # or its multiplier is active
                    ineq_filter = (ineq_defect >= 0) | (self.ineq_multipliers() > 0)
                    ineq_square = torch.sum(torch.square(ineq_defect[ineq_filter]))
                else:
                    ineq_square = torch.tensor([0.0], device=loss.device)

                if eq_defect is not None:
                    eq_square = torch.sum(torch.square(eq_defect))
                else:
                    eq_square = torch.tensor([0.0], device=loss.device)
lagrangian += self.aug_lag_coefficient * (ineq_square + eq_square)
else:
assert cmp_state.loss is not None
lagrangian = cmp_state.loss
return lagrangian
@no_type_check
def _populate_gradients(
self, lagrangian: torch.Tensor, ignore_primal: bool = False
):
"""
Performs the actual backward computation which populates the gradients
for the primal and dual variables.
Args:
lagrangian: Value of the computed Lagrangian based on which the
gradients for the primal and dual variables are populated.
ignore_primal: If ``True``, only the gradients with respect to the
dual variables are populated (these correspond to the constraint
violations). This feature is mainly used in conjunction with
``alternating`` updates, which require updating the multipliers
based on the constraints violation *after* having updated the
primal parameters. Defaults to False.
"""
if ignore_primal and self.cmp.is_constrained:
# Only compute gradients wrt Lagrange multipliers
# No need to call backward on Lagrangian as the dual variables have
# been detached when computing the `weighted_violation`s
pass
else:
# Compute gradients wrt primal parameters only.
# The gradient for the dual variables is computed based on the
# non-proxy violations below.
lagrangian.backward()
# Fill in the gradients for the dual variables based on the violation of
# the non-proxy constraint
# This is equivalent to setting `dual_vars.grad = defect`
        if self.cmp.is_constrained:
            dual_vars = self.dual_parameters
            for violation_for_update in self.state_update:
                violation_for_update.backward(inputs=dual_vars)
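
# NOTE: `_populate_gradients` above relies on `torch.Tensor.backward(inputs=...)`
# (available in PyTorch >= 1.8) to restrict gradient accumulation to the dual
# variables. A minimal, self-contained sketch of that behavior, with made-up
# tensors standing in for the primal and dual variables:
def _backward_inputs_sketch():
    x = torch.tensor([1.0], requires_grad=True)  # stand-in primal variable
    lam = torch.tensor([0.1], requires_grad=True)  # stand-in dual variable
    out = lam * (x ** 2)

    # Only `lam.grad` is populated; `x.grad` is left untouched (None)
    out.backward(inputs=[lam])
    assert lam.grad is not None and x.grad is None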
class ProxyLagrangianFormulation(BaseLagrangianFormulation):
"""
Placeholder class for the proxy-Lagrangian formulation proposed by
:cite:t:`cotter2019JMLR`.
.. todo::
Implement Proxy-Lagrangian formulation as described in
:cite:t:`cotter2019JMLR`
"""
pass
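
# NOTE: A minimal sketch of the augmented Lagrangian penalty computed in
# `LagrangianFormulation.composite_objective`, following the filtered
# squared-penalty form referenced there (Toussaint's slides). All tensors
# below are made up for illustration; `aug_lag_coefficient` plays the role
# of the constructor argument of the same name.
def _augmented_penalty_sketch():
    aug_lag_coefficient = 10.0
    ineq_defect = torch.tensor([-0.5, 0.3])  # g(x); feasible iff g(x) <= 0
    eq_defect = torch.tensor([0.2])  # h(x); feasible iff h(x) == 0
    multipliers = torch.tensor([0.7, 0.0])  # current inequality multipliers

    # Penalize an inequality only if it is violated or its multiplier is active
    ineq_filter = (ineq_defect >= 0) | (multipliers > 0)
    ineq_square = torch.sum(torch.square(ineq_defect[ineq_filter]))
    eq_square = torch.sum(torch.square(eq_defect))

    # Term added to the Lagrangian: c * (filtered sum g_i^2 + sum h_j^2)
    penalty = aug_lag_coefficient * (ineq_square + eq_square)
    # ineq_filter == [True, True]: g_0 has an active multiplier, g_1 is violated
    assert torch.isclose(penalty, torch.tensor(10.0 * (0.25 + 0.09 + 0.04)))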