# Copyright (C) 2025 The Cooper Developers.
# Licensed under the MIT License.
"""Implementation of constrained optimizers based on alternation such as
:py:class:`AlternatingPrimalDualOptimizer` and :py:class:`AlternatingDualPrimalOptimizer`.
"""
from typing import Optional
import torch
from cooper.optim.constrained_optimizers.constrained_optimizer import ConstrainedOptimizer
from cooper.optim.optimizer import RollOut
class AlternatingPrimalDualOptimizer(ConstrainedOptimizer):
    r"""Solves a :py:class:`~cooper.ConstrainedMinimizationProblem` with alternating
    updates in which the primal variables move first and the dual variables second.

    The concrete update rule depends on the primal and dual optimizers supplied, and
    has the general form:

    .. math::

        \vx_{t+1} &= \texttt{primal_optimizer_update} \left( \vx_{t}, \nabla_{\vx}
            \Lag_{\text{primal}}(\vx, \vlambda_t, \vmu_t)|_{\vx=\vx_t} \right)

        \vlambda_{t+1} &= \left[ \texttt{dual_optimizer_update} \left( \vlambda_{t},
            \nabla_{\vlambda} \Lag_{\text{dual}}({\vx_{\color{red} t+1}}, \vlambda, \vmu_t)|_{\vlambda=\vlambda_t}
            \right) \right]_+

        \vmu_{t+1} &= \texttt{dual_optimizer_update} \left( \vmu_{t}, \nabla_{\vmu}
            \Lag_{\text{dual}}({\vx_{\color{red} t+1}}, \vlambda_t, \vmu)|_{\vmu=\vmu_t} \right)

    As an example, alternating projected gradient descent-ascent applied to a
    :py:class:`~cooper.formulations.Lagrangian` formulation reads:

    .. math::

        \vx_{t+1} &= \vx_t - \eta_{\vx} \left [ \nabla_{\vx} f(\vx_t) + \vlambda_t^\top
            \nabla_{\vx} \vg(\vx_t) + \vmu_t^\top \nabla_{\vx} \vh(\vx_t) \right ],

        \vlambda_{t+1} &= \left [ \vlambda_t + \eta_{\vlambda} \vg(\vx_{\color{red} t+1}) \right ]_+,

        \vmu_{t+1} &= \vmu_t + \eta_{\vmu} \vh(\vx_{\color{red} t+1}),

    with step sizes :math:`\eta_{\vx}`, :math:`\eta_{\vlambda}` and :math:`\eta_{\vmu}`.

    Constraint violations are therefore evaluated *twice* per step: once at
    :math:`\vx_{t}`, to build the primal Lagrangian :math:`\Lag_{\text{primal}}` that
    drives the primal update, and once more at the new iterate :math:`\vx_{t+1}`, to
    build the dual Lagrangian :math:`\Lag_{\text{dual}}` that drives the dual update.

    .. admonition:: Reducing computational overhead in primal-dual alternating updates
        :class: note

        The dual update only needs the constraint violations
        :math:`\vg(\vx_{\color{red} t+1})` and :math:`\vh(\vx_{\color{red} t+1})`;
        the objective value :math:`f(\vx_{\color{red} t+1})` is not required.
        Implementing :py:meth:`~cooper.ConstrainedMinimizationProblem.compute_violations()`
        on the CMP, and passing ``compute_violations_kwargs`` to :py:meth:`roll()`,
        lets the optimizer re-evaluate only the violations at
        :math:`\vx_{\color{red} t+1}`. This skips the loss and builds no
        computational graph over the primal variables.
    """

    def roll(
        self, compute_cmp_state_kwargs: Optional[dict] = None, compute_violations_kwargs: Optional[dict] = None
    ) -> RollOut:
        r"""Runs one alternating step: primal update first, then dual update.

        Args:
            compute_cmp_state_kwargs: Keyword arguments forwarded to
                :py:meth:`~cooper.ConstrainedMinimizationProblem.compute_cmp_state()`.
            compute_violations_kwargs: Keyword arguments forwarded to
                :py:meth:`~cooper.ConstrainedMinimizationProblem.compute_violations()`.
                If the CMP implements
                :py:meth:`~cooper.ConstrainedMinimizationProblem.compute_violations()`,
                that method is preferred over
                :py:meth:`~cooper.ConstrainedMinimizationProblem.compute_cmp_state()`
                for the dual update. If it raises ``NotImplementedError``, the
                violations are taken from a fresh call to
                :py:meth:`~cooper.ConstrainedMinimizationProblem.compute_cmp_state()`
                at the updated primal iterate.

        Returns:
            :py:class:`~cooper.optim.optimizer.RollOut`: A named tuple with:

            - loss (:py:class:`~torch.Tensor`):
              The latest available loss. It is :math:`f(\vx_{t})` when
              :py:meth:`~cooper.ConstrainedMinimizationProblem.compute_violations()`
              was used, and the re-evaluated :math:`f(\vx_{t+1})` otherwise.
            - cmp_state (:py:class:`~cooper.CMPState`):
              The CMP state at :math:`\vx_{\color{red} t+1}`. When
              :py:meth:`~cooper.ConstrainedMinimizationProblem.compute_violations()`
              is used, the loss at :math:`\vx_{t+1}` is never computed, so
              ``cmp_state.loss`` is ``None``.
            - primal_lagrangian_store (:py:class:`~cooper.LagrangianStore`):
              Primal Lagrangian store at :math:`\vx_{t}`, :math:`\vlambda_t` and
              :math:`\vmu_t`.
            - dual_lagrangian_store (:py:class:`~cooper.LagrangianStore`):
              Dual Lagrangian store at :math:`\vx_{\color{red} t+1}`,
              :math:`\vlambda_t` and :math:`\vmu_t`.

        Raises:
            RuntimeError: If ``compute_violations`` returns a state whose ``loss``
                is populated.
        """
        cmp_state_kwargs = {} if compute_cmp_state_kwargs is None else compute_cmp_state_kwargs
        violations_kwargs = {} if compute_violations_kwargs is None else compute_violations_kwargs

        self.zero_grad()
        state_at_current = self.cmp.compute_cmp_state(**cmp_state_kwargs)

        # Primal step: gradient of the primal Lagrangian at x_t with the current multipliers.
        primal_store = state_at_current.compute_primal_lagrangian()
        primal_store.backward()
        self.primal_step()

        # Re-measure the violations at x_{t+1}. The multipliers play no part in
        # computing the CMP state, so no graph over the primal parameters is needed
        # here. Only gradients wrt the dual variables are required later.
        with torch.no_grad():
            try:
                state_at_next = self.cmp.compute_violations(**violations_kwargs)
            except NotImplementedError:
                # The CMP does not provide `compute_violations`: do a full re-evaluation.
                state_at_next = self.cmp.compute_cmp_state(**cmp_state_kwargs)
            else:
                if state_at_next.loss is not None:
                    raise RuntimeError(
                        "Expected `compute_violations` to not populate the loss. "
                        "Please provide this value for the `compute_cmp_state` instead."
                    )

        # Dual step. This runs outside `no_grad` so the dual Lagrangian builds a graph
        # wrt the multipliers.
        dual_store = state_at_next.compute_dual_lagrangian()
        dual_store.backward()
        self.dual_step()

        # Prefer the loss at x_{t+1} when it was computed; otherwise report f(x_t).
        latest_loss = state_at_current.loss if state_at_next.loss is None else state_at_next.loss
        return RollOut(latest_loss, state_at_next, primal_store, dual_store)
class AlternatingDualPrimalOptimizer(ConstrainedOptimizer):
    r"""Solves a :py:class:`~cooper.ConstrainedMinimizationProblem` with alternating
    updates in which the dual variables move first and the primal variables second.

    The concrete update rule depends on the primal and dual optimizers supplied, and
    has the general form:

    .. math::

        \vlambda_{t+1} &= \left[ \texttt{dual_optimizer_update} \left( \vlambda_{t},
            \nabla_{\vlambda} \Lag_{\text{dual}}(\vx_t, \vlambda, \vmu_t)|_{\vlambda=\vlambda_t}
            \right) \right]_+

        \vmu_{t+1} &= \texttt{dual_optimizer_update} \left( \vmu_{t}, \nabla_{\vmu}
            \Lag_{\text{dual}}(\vx_t, \vlambda_t, \vmu)|_{\vmu=\vmu_t} \right)

        \vx_{t+1} &= \texttt{primal_optimizer_update} \left( \vx_{t}, \nabla_{\vx}
            \Lag_{\text{primal}}(\vx, \vlambda_{\color{red} t+1}, \vmu_{\color{red} t+1}
            )|_{\vx=\vx_t} \right)

    As an example, alternating projected gradient descent-ascent applied to a
    :py:class:`~cooper.formulations.Lagrangian` formulation reads:

    .. math::

        \vlambda_{t+1} &= \left [ \vlambda_t + \eta_{\vlambda} \vg(\vx_t) \right ]_+,

        \vmu_{t+1} &= \vmu_t + \eta_{\vmu} \vh(\vx_t),

        \vx_{t+1} &= \vx_t - \eta_{\vx} \left [ \nabla_{\vx} f(\vx_t) +
            \vlambda_{\color{red} t+1}^\top \nabla_{\vx} \vg(\vx_t) +
            \vmu_{\color{red} t+1}^\top \nabla_{\vx} \vh(\vx_t) \right ],

    with step sizes :math:`\eta_{\vx}`, :math:`\eta_{\vlambda}` and :math:`\eta_{\vmu}`.

    .. note::

        Both updates rely on the :py:class:`~cooper.CMPState` evaluated at the
        current primal iterate :math:`\vx_{t}`. The primal update uses the refreshed
        multipliers :math:`\vlambda_{\color{red} t+1}` and
        :math:`\vmu_{\color{red} t+1}`, but the :py:class:`~cooper.CMPState` itself
        **does not have to be re-evaluated after the dual update**. Each step
        therefore costs the same as one step of the
        :py:class:`~cooper.optim.constrained_optimizers.SimultaneousOptimizer`.
    """

    def roll(self, compute_cmp_state_kwargs: Optional[dict] = None) -> RollOut:
        r"""Runs one alternating step: dual update first, then primal update.

        Args:
            compute_cmp_state_kwargs: Keyword arguments forwarded to
                :py:meth:`~cooper.ConstrainedMinimizationProblem.compute_cmp_state()`.

        Returns:
            :py:class:`~cooper.optim.optimizer.RollOut`: A named tuple with:

            - loss (:py:class:`~torch.Tensor`):
              The loss computed during this step, :math:`f(\vx_{t})`.
            - cmp_state (:py:class:`~cooper.CMPState`):
              The CMP state at :math:`\vx_{t}`.
            - primal_lagrangian_store (:py:class:`~cooper.LagrangianStore`):
              Primal Lagrangian store at :math:`\vx_{t}`,
              :math:`\vlambda_{\color{red} t+1}` and :math:`\vmu_{\color{red} t+1}`.
            - dual_lagrangian_store (:py:class:`~cooper.LagrangianStore`):
              Dual Lagrangian store at :math:`\vx_{t}`, :math:`\vlambda_t` and
              :math:`\vmu_t`.
        """
        cmp_state_kwargs = compute_cmp_state_kwargs if compute_cmp_state_kwargs is not None else {}

        self.zero_grad()
        state = self.cmp.compute_cmp_state(**cmp_state_kwargs)

        # Dual step, driven by the violations measured at x_t.
        dual_store = state.compute_dual_lagrangian()
        dual_store.backward()
        self.dual_step()

        # Primal step. The CMP state (objective and violations at x_t) is reused, but
        # the primal Lagrangian is rebuilt only now so that it reflects the
        # just-updated dual variables.
        primal_store = state.compute_primal_lagrangian()
        primal_store.backward()
        self.primal_step()

        return RollOut(state.loss, state, primal_store, dual_store)