# Source code for cooper.optim.constrained_optimizers.simultaneous_optimizer

# Copyright (C) 2025 The Cooper Developers.
# Licensed under the MIT License.

"""Implementation of the :py:class:`SimultaneousOptimizer` class."""

from typing import Optional

from cooper.optim.constrained_optimizers.constrained_optimizer import ConstrainedOptimizer
from cooper.optim.optimizer import RollOut


class SimultaneousOptimizer(ConstrainedOptimizer):
    r"""Constrained optimizer that updates primal and dual variables simultaneously.

    Optimizes a :py:class:`~cooper.ConstrainedMinimizationProblem` by performing
    simultaneous gradient updates to the primal and dual variables.

    According to the choice of primal and dual optimizers, the updates are performed
    as follows:

    .. math::
        \vx_{t+1} &= \texttt{primal_optimizer_update} \left( \vx_{t}, \nabla_{\vx}
            \Lag_{\text{primal}}(\vx, \vlambda_t, \vmu_t)|_{\vx=\vx_t} \right)

        \vlambda_{t+1} &= \left[ \texttt{dual_optimizer_update} \left( \vlambda_{t},
            \nabla_{\vlambda} \Lag_{\text{dual}}({\vx_{t}}, \vlambda,
            \vmu_t)|_{\vlambda=\vlambda_t} \right) \right]_+

        \vmu_{t+1} &= \texttt{dual_optimizer_update} \left( \vmu_{t}, \nabla_{\vmu}
            \Lag_{\text{dual}}({\vx_{t}}, \vlambda_t, \vmu)|_{\vmu=\vmu_t} \right)

    For instance, when the primal/dual updates are gradient descent/ascent on a
    :py:class:`~cooper.formulations.Lagrangian` formulation, the updates are as
    follows:

    .. math::
        \vx_{t+1} &= \vx_t - \eta_{\vx} \left [ \nabla_{\vx} f(\vx_t) +
            \vlambda_t^\top \nabla_{\vx} \vg(\vx_t) + \vmu_t^\top \nabla_{\vx}
            \vh(\vx_t) \right ],

        \vlambda_{t+1} &= \left [ \vlambda_t + \eta_{\vlambda} \vg(\vx_t) \right ]_+,

        \vmu_{t+1} &= \vmu_t + \eta_{\vmu} \vh(\vx_t),

    where :math:`\eta_{\vx}`, :math:`\eta_{\vlambda}`, and :math:`\eta_{\vmu}` are
    step sizes.
    """

    def roll(self, compute_cmp_state_kwargs: Optional[dict] = None) -> RollOut:
        """Evaluates the :py:class:`~cooper.CMPState` and performs a simultaneous
        primal-dual update.

        Gradients are cleared first, then the primal and dual Lagrangians are both
        computed from a single CMP state and back-propagated before either the primal
        or the dual step is taken.

        Args:
            compute_cmp_state_kwargs: Keyword arguments to pass to the
                :py:meth:`~cooper.ConstrainedMinimizationProblem.compute_cmp_state`
                method. ``None`` is treated as "no keyword arguments".

        Returns:
            A :py:class:`RollOut` built from the CMP state's loss, the CMP state
            itself, and the primal and dual Lagrangian stores (in that order).
        """
        state_kwargs = {} if compute_cmp_state_kwargs is None else compute_cmp_state_kwargs

        self.zero_grad()
        cmp_state = self.cmp.compute_cmp_state(**state_kwargs)

        # TODO (gallego-posada): The current design goes over the constraints twice. We
        # could reduce overhead by writing a dedicated compute_lagrangian function for
        # the simultaneous updates
        primal_store = cmp_state.compute_primal_lagrangian()
        dual_store = cmp_state.compute_dual_lagrangian()

        # The two backward passes could run in either order: the primal and dual
        # Lagrangians have independent gradients.
        primal_store.backward()
        dual_store.backward()

        # Likewise, the two steps update independent primal and dual parameters, so
        # their relative order does not matter.
        self.primal_step()
        self.dual_step()

        return RollOut(cmp_state.loss, cmp_state, primal_store, dual_store)