# Source code for cooper.optim.constrained_optimizers.extrapolation_optimizer
# Copyright (C) 2025 The Cooper Developers.
# Licensed under the MIT License.
"""Implementation of the :py:class:`ExtrapolationConstrainedOptimizer` class."""
from typing import Optional
import torch
from cooper.optim.constrained_optimizers.constrained_optimizer import ConstrainedOptimizer
from cooper.optim.optimizer import RollOut
# [docs]
class ExtrapolationConstrainedOptimizer(ConstrainedOptimizer):
    r"""Optimizes a :py:class:`~cooper.ConstrainedMinimizationProblem` by performing
    extrapolation updates to the primal and dual variables.

    Given the choice of primal and dual optimizers, an **extrapolation** step is performed
    first:

    .. math::
        \vx_{t+\frac{1}{2}} &= \texttt{primal_optimizer_update} \left( \vx_{t},
            \nabla_{\vx} \Lag_{\text{primal}}(\vx, \vlambda_t, \vmu_t)|_{\vx=\vx_t}
            \right)

        \vlambda_{t+\frac{1}{2}} &= \left[ \texttt{dual_optimizer_update} \left(
            \vlambda_{t}, \nabla_{\vlambda} \Lag_{\text{dual}}({\vx_{t}}, \vlambda,
            \vmu_t) |_{\vlambda=\vlambda_t} \right) \right]_+

        \vmu_{t+\frac{1}{2}} &= \texttt{dual_optimizer_update} \left( \vmu_{t},
            \nabla_{\vmu} \Lag_{\text{dual}}({\vx_{t}}, \vlambda_{t}, \vmu)
            |_{\vmu=\vmu_t} \right).

    This is followed by an **update** step, which modifies the primal and dual variables
    from step :math:`t`, based on the gradients *computed at the extrapolated points*
    :math:`t+\frac{1}{2}`:

    .. math::
        \vx_{t+1} &= \texttt{primal_optimizer_update} \left( \vx_{t}, \nabla_{\vx}
            \Lag_{\text{primal}} \left(\vx, \vlambda_{\color{red} t+\frac{1}{2}},
            \vmu_{\color{red} t+\frac{1}{2}} \right)|_{\vx=\vx_{\color{red}
            t+\frac{1}{2}}} \right)

        \vlambda_{t+1} &= \left[ \texttt{dual_optimizer_update} \left(
            \vlambda_{t}, \nabla_{\vlambda} \Lag_{\text{dual}} \left({\vx_{\color{red}
            t+\frac{1}{2}}}, \vlambda, \vmu_{\color{red} t+\frac{1}{2}} \right)
            |_{\vlambda=\vlambda_{\color{red} t+\frac{1}{2}}}\right) \right]_+

        \vmu_{t+1} &= \texttt{dual_optimizer_update} \left( \vmu_{t}, \nabla_{\vmu}
            \Lag_{\text{dual}}\left({\vx_{\color{red} t+\frac{1}{2}}},
            \vlambda_{\color{red} t+\frac{1}{2}}, \vmu \right) |_{\vmu=\vmu_{\color{red}
            t+\frac{1}{2}}} \right).

    For example, if the primal optimizer is gradient descent and the dual optimizer is
    gradient ascent, the extrapolation step leads to:

    .. math::
        \vx_{t+\frac{1}{2}} &= \vx_t - \eta_{\vx} \left [ \nabla_{\vx} f(\vx_t) +
            \vlambda_t^\top \nabla_{\vx} \vg(\vx_t) + \vmu_t^\top \nabla_{\vx}
            \vh(\vx_t) \right ],

        \vlambda_{t+\frac{1}{2}} &= \left [ \vlambda_t + \eta_{\vlambda} \vg(\vx_{t})
            \right ]_+,

        \vmu_{t+\frac{1}{2}} &= \vmu_t + \eta_{\vmu} \vh(\vx_t).

    The update step then yields (note that the updates start from the variables at
    step :math:`t`, not from the extrapolated ones):

    .. math::
        \vx_{t+1} &= \vx_t - \eta_{\vx} \left [ \nabla_{\vx} f \left(\vx_{\color{red}
            t+\frac{1}{2}}\right) + \vlambda_{\color{red} t+\frac{1}{2}}^\top
            \nabla_{\vx} \vg \left(\vx_{\color{red} t+\frac{1}{2}} \right) +
            \vmu_{\color{red} t+\frac{1}{2}}^\top \nabla_{\vx} \vh\left(\vx_{\color{red}
            t+\frac{1}{2}} \right) \right ],

        \vlambda_{t+1} &= \left [ \vlambda_{t} + \eta_{\vlambda}
            \vg(\vx_{\color{red} t+\frac{1}{2}}) \right ]_+,

        \vmu_{t+1} &= \vmu_{t} + \eta_{\vmu} \vh(\vx_{\color{red}
            t+\frac{1}{2}}).

    The :py:meth:`~cooper.optim.constrained_optimizers.ExtrapolationConstrainedOptimizer.roll()`
    method first calls the
    :py:meth:`~cooper.optim.torch_optimizers.ExtragradientOptimizer.extrapolation()`
    methods of the primal and dual optimizers, and then, after recomputing the
    gradients, their
    :py:meth:`~cooper.optim.torch_optimizers.ExtragradientOptimizer.step()` methods.
    """

    def custom_sanity_checks(self) -> None:
        """Perform custom sanity checks on the initialization of the optimizer.

        Raises:
            RuntimeError: Tried to construct an
                :py:class:`ExtrapolationConstrainedOptimizer` but some of the provided
                optimizers do not have an extrapolation method.
        """
        all_optimizers = [*self.primal_optimizers, *self.dual_optimizers]
        if not all(hasattr(optimizer, "extrapolation") for optimizer in all_optimizers):
            # Adjacent literals keep the message free of embedded newlines and
            # source-code indentation.
            raise RuntimeError(
                "Some of the provided optimizers do not have an extrapolation method. "
                "Please ensure that all optimizers are extrapolation capable."
            )

    @torch.no_grad()
    def primal_extrapolation_step(self) -> None:
        """Perform an extrapolation step on the parameters associated with the primal
        variables.

        Calls ``extrapolation()`` on every primal optimizer, with gradient tracking
        disabled.
        """
        for primal_optimizer in self.primal_optimizers:
            primal_optimizer.extrapolation()

    @torch.no_grad()
    def dual_extrapolation_step(self) -> None:
        """Perform an extrapolation step on the parameters associated with the dual
        variables.

        After being updated by the dual optimizer steps, the multipliers are
        post-processed (e.g. to ensure non-negativity for inequality constraints).
        """
        # Update multipliers based on current constraint violations (gradients)
        # For unobserved constraints the gradient is None, so this is a no-op.
        for dual_optimizer in self.dual_optimizers:
            dual_optimizer.extrapolation()

        # Post-process every multiplier of the CMP once all dual optimizers have
        # extrapolated.
        for multiplier in self.cmp.multipliers():
            multiplier.post_step_()

    def roll(self, compute_cmp_state_kwargs: Optional[dict] = None) -> RollOut:
        r"""Performs a full update step on the primal and dual variables.

        Note that the forward and backward computations are carried out
        *twice*, as part of the
        :py:meth:`~cooper.optim.torch_optimizers.ExtragradientOptimizer.extrapolation()`
        and :py:meth:`~cooper.optim.torch_optimizers.ExtragradientOptimizer.step()`
        calls.

        Args:
            compute_cmp_state_kwargs: Keyword arguments to pass to the
                :py:meth:`~cooper.ConstrainedMinimizationProblem.compute_cmp_state()`
                method. ``None`` is treated as an empty dictionary.

        Returns:
            :py:class:`~cooper.optim.optimizer.RollOut`: A named tuple containing the
            following objects:

            - loss (:py:class:`~torch.Tensor`):
                The loss value computed *before* the extrapolation step
                :math:`f(\vx_{t})`.
            - cmp_state (:py:class:`~cooper.CMPState`):
                The CMP state at :math:`\vx_{t}`.
            - primal_lagrangian_store (:py:class:`~cooper.LagrangianStore`):
                The primal Lagrangian store at :math:`\vx_{t}`,
                :math:`\vlambda_{t}` and :math:`\vmu_{t}`.
            - dual_lagrangian_store (:py:class:`~cooper.LagrangianStore`):
                The dual Lagrangian store at :math:`\vx_{t}`, :math:`\vlambda_t` and
                :math:`\vmu_t`.

        .. note::
            The `RollOut` for this scheme returns the loss and `CMPState` values at the
            original point :math:`(\vx_t, \vlambda_t)`, *before* any of the updates are
            performed.
        """
        if compute_cmp_state_kwargs is None:
            compute_cmp_state_kwargs = {}

        # Two passes: the first extrapolates to t + 1/2, the second performs the
        # actual update using gradients computed at the extrapolated point.
        for call_extrapolation in (True, False):
            self.zero_grad()
            cmp_state = self.cmp.compute_cmp_state(**compute_cmp_state_kwargs)

            primal_lagrangian_store = cmp_state.compute_primal_lagrangian()
            dual_lagrangian_store = cmp_state.compute_dual_lagrangian()

            if call_extrapolation:
                # Record the quantities at the original point; this is what gets
                # returned to the caller.
                roll_out = RollOut(cmp_state.loss, cmp_state, primal_lagrangian_store, dual_lagrangian_store)

            primal_lagrangian_store.backward()
            dual_lagrangian_store.backward()

            if call_extrapolation:
                self.primal_extrapolation_step()
                self.dual_extrapolation_step()
            else:
                self.primal_step()
                self.dual_step()

        return roll_out