# Copyright (C) 2025 The Cooper Developers.
# Licensed under the MIT License.
import abc
from typing import Any, NamedTuple, Optional, TypedDict
import torch
from cooper.cmp import CMPState, ConstrainedMinimizationProblem, LagrangianStore
from cooper.utils import OneOrSequence, ensure_sequence
class CooperOptimizerState(TypedDict):
    r"""Serializable snapshot of a :py:class:`~cooper.optim.CooperOptimizer`.

    Keys:
        primal_optimizer_states: One ``state_dict`` per primal optimizer, stored as
            a list.
        dual_optimizer_states: One ``state_dict`` per dual optimizer, stored as a
            list. Set to ``None`` when the optimizer is an unconstrained optimizer.
    """

    # ``state_dict``\s of the primal optimizers.
    primal_optimizer_states: list[dict]
    # ``state_dict``\s of the dual optimizers; ``None`` for unconstrained optimizers.
    dual_optimizer_states: Optional[list[dict]]
class RollOut(NamedTuple):
    """Result bundle returned by :py:meth:`~cooper.optim.CooperOptimizer.roll()`.

    Fields (in positional order):
        loss (:py:class:`torch.Tensor`): Value of the objective function.
        cmp_state (:py:class:`~cooper.cmp.CMPState`): State of the CMP.
        primal_lagrangian_store (:py:class:`~cooper.LagrangianStore`): LagrangianStore
            for the primal Lagrangian.
        dual_lagrangian_store (:py:class:`~cooper.LagrangianStore`): LagrangianStore
            for the dual Lagrangian.
    """

    # Objective function value.
    loss: torch.Tensor
    # State of the CMP.
    cmp_state: CMPState
    # Store associated with the primal Lagrangian.
    primal_lagrangian_store: LagrangianStore
    # Store associated with the dual Lagrangian.
    dual_lagrangian_store: LagrangianStore
class CooperOptimizer(abc.ABC):
    r"""Base class for :py:class:`~cooper.optim.constrained_optimizer.ConstrainedOptimizer`
    and :py:class:`~cooper.optim.UnconstrainedOptimizer`\s.

    Args:
        cmp: The constrained minimization problem to be optimized. Providing the CMP
            as an argument for the constructor allows the optimizer to call the
            :py:meth:`~cooper.ConstrainedMinimizationProblem.compute_cmp_state`
            method within the :py:meth:`~cooper.optim.cooper_optimizer.CooperOptimizer.roll`
            method. Additionally, in the case of a constrained optimizer, the CMP
            enables access to the multipliers'
            :py:meth:`~cooper.multipliers.Multiplier.post_step_` method which must be
            called after the multiplier update.
        primal_optimizers: Optimizer(s) for the primal variables (e.g. the weights of
            a model). The primal parameters can be partitioned into multiple optimizers,
            in this case ``primal_optimizers`` accepts a list of
            :py:class:`torch.optim.Optimizer`\s.
        dual_optimizers: Optimizer(s) for the dual variables (e.g. the Lagrange
            multipliers associated with the constraints). A sequence of
            :py:class:`torch.optim.Optimizer`\s can be passed to handle the case of
            several :py:class:`~cooper.constraints.Constraint`\s.
    """

    def __init__(
        self,
        cmp: ConstrainedMinimizationProblem,
        primal_optimizers: OneOrSequence[torch.optim.Optimizer],
        dual_optimizers: Optional[OneOrSequence[torch.optim.Optimizer]] = None,
    ) -> None:
        self.cmp = cmp
        # NOTE(review): ``ensure_sequence`` is presumed to wrap a single optimizer in a
        # sequence and to pass ``None`` through unchanged; the methods below rely on
        # ``self.dual_optimizers`` possibly being ``None``. Confirm against cooper.utils.
        self.primal_optimizers = ensure_sequence(primal_optimizers)
        self.dual_optimizers = ensure_sequence(dual_optimizers)

    def zero_grad(self) -> None:
        r"""Sets the gradients of all optimized :py:class:`~torch.nn.parameter.Parameter`\s
        to zero. This includes both the primal and dual variables.
        """
        for primal_optimizer in self.primal_optimizers:
            # Prior to PyTorch 2.0, set_to_none=False was the default behavior.
            # The default behavior was changed to set_to_none=True in PyTorch 2.0.
            # We set set_to_none=True explicitly to ensure compatibility with both versions.
            primal_optimizer.zero_grad(set_to_none=True)

        if self.dual_optimizers is not None:
            for dual_optimizer in self.dual_optimizers:
                dual_optimizer.zero_grad(set_to_none=True)

    @torch.no_grad()
    def primal_step(self) -> None:
        """Performs a gradient step on the parameters associated with the primal variables."""
        for primal_optimizer in self.primal_optimizers:
            primal_optimizer.step()

    def state_dict(self) -> CooperOptimizerState:
        r"""Returns the state of the optimizer as a
        :py:class:`~cooper.optim.cooper_optimizer.CooperOptimizerState`. This method
        relies on the internal :py:meth:`~torch.optim.Optimizer.state_dict` method of
        the corresponding primal or dual optimizers.

        Returns:
            A ``CooperOptimizerState`` whose ``primal_optimizer_states`` holds one
            ``state_dict`` per primal optimizer and whose ``dual_optimizer_states``
            holds one ``state_dict`` per dual optimizer, or ``None`` when
            ``self.dual_optimizers`` is ``None``.
        """
        primal_optimizer_states = [optimizer.state_dict() for optimizer in self.primal_optimizers]

        dual_optimizer_states = None
        if self.dual_optimizers is not None:
            dual_optimizer_states = [optimizer.state_dict() for optimizer in self.dual_optimizers]

        return CooperOptimizerState(
            primal_optimizer_states=primal_optimizer_states, dual_optimizer_states=dual_optimizer_states
        )

    def load_state_dict(self, state: CooperOptimizerState) -> None:
        """Loads the optimizer state from the given state dictionary.

        All consistency checks run before any optimizer is modified, so a rejected
        ``state`` leaves every primal and dual optimizer untouched.

        Args:
            state: A dictionary containing the optimizer state.

        Raises:
            ValueError: If the number of primal optimizers does not match the number of primal optimizer states.
            ValueError: If the number of dual optimizers does not match the number of dual optimizer states.
            ValueError: If ``dual_optimizer_states`` is present in the state dict but ``dual_optimizers`` is None.
            ValueError: If ``dual_optimizer_states`` is None in the state dict but ``dual_optimizers`` is not None.
        """
        primal_optimizer_states = state["primal_optimizer_states"]
        if len(primal_optimizer_states) != len(self.primal_optimizers):
            raise ValueError("The number of primal optimizers does not match the number of primal optimizer states.")

        dual_optimizer_states = state["dual_optimizer_states"]
        if self.dual_optimizers is None:
            if dual_optimizer_states is not None:
                raise ValueError(
                    "Optimizer state dict contains ``dual_optimizer_states`` but ``dual_optimizers`` is None."
                )
        elif dual_optimizer_states is None:
            # Without this guard, ``len(None)`` below would surface as an opaque
            # TypeError instead of a ValueError like every other mismatch.
            raise ValueError(
                "Optimizer state dict has ``dual_optimizer_states`` set to None but ``dual_optimizers`` is not None."
            )
        elif len(dual_optimizer_states) != len(self.dual_optimizers):
            raise ValueError("The number of dual optimizers does not match the number of dual optimizer states.")

        # Lengths were validated above, so ``zip`` pairs every optimizer with a state.
        for primal_optimizer, primal_optimizer_state in zip(self.primal_optimizers, primal_optimizer_states):
            primal_optimizer.load_state_dict(primal_optimizer_state)

        if self.dual_optimizers is not None:
            for dual_optimizer, dual_optimizer_state in zip(self.dual_optimizers, dual_optimizer_states):
                dual_optimizer.load_state_dict(dual_optimizer_state)

    @abc.abstractmethod
    def roll(self, *args: Any, **kwargs: Any) -> RollOut:
        """Evaluates the objective function and performs a gradient update on the parameters."""