Source code for cooper.optim.unconstrained_optimizer

# Copyright (C) 2025 The Cooper Developers.
# Licensed under the MIT License.

"""Implementation of the :py:class:`UnconstrainedOptimizer` class."""

from typing import Optional

from cooper.cmp import LagrangianStore
from cooper.optim.optimizer import CooperOptimizer, RollOut


class UnconstrainedOptimizer(CooperOptimizer):
    r"""Wraps a (sequence of) :py:class:`torch.optim.Optimizer`\s to enable handling
    unconstrained minimization problems in a way that is consistent with
    :py:class:`~cooper.optim.constrained_optimizers.ConstrainedOptimizer`\s.

    Args:
        cmp: The constrained minimization problem to be optimized. Providing the CMP
            as an argument for the constructor allows the optimizer to call the
            :py:meth:`~cooper.ConstrainedMinimizationProblem.compute_cmp_state`
            method within the :py:meth:`~cooper.optim.optimizer.CooperOptimizer.roll`
            method.
        primal_optimizers: Optimizer(s) for the primal variables (e.g. the weights of
            a model). The primal parameters can be partitioned into multiple
            optimizers, in which case ``primal_optimizers`` accepts a list of
            :py:class:`torch.optim.Optimizer`\s.
    """

    def roll(self, compute_cmp_state_kwargs: Optional[dict] = None) -> RollOut:
        """Evaluates the objective function and performs a gradient update on the
        parameters.

        The steps run in this order: reset gradients, compute the CMP state, build
        the primal Lagrangian from it, backpropagate that Lagrangian, and take a
        primal optimizer step.

        Args:
            compute_cmp_state_kwargs: Keyword arguments to pass to the
                :py:meth:`~cooper.ConstrainedMinimizationProblem.compute_cmp_state`
                method. ``None`` (the default) is treated as no keyword arguments.
                Since this is an unconstrained optimizer, the CMPState is expected
                to contain just the loss.

        Returns:
            A :py:class:`~cooper.optim.optimizer.RollOut` built from, in order: the
            ``loss`` of the computed CMPState, the CMPState itself, the primal
            :py:class:`~cooper.cmp.LagrangianStore` that was backpropagated, and a
            freshly constructed (empty) dual ``LagrangianStore``, because
            unconstrained problems have no dual update.
        """
        # Avoid a mutable default argument; an absent mapping means "no kwargs".
        if compute_cmp_state_kwargs is None:
            compute_cmp_state_kwargs = {}

        # Reset gradients before this step's backward pass.
        self.zero_grad()

        cmp_state = self.cmp.compute_cmp_state(**compute_cmp_state_kwargs)

        # Only the primal Lagrangian is formed; there are no multipliers to update.
        lagrangian_store = cmp_state.compute_primal_lagrangian()
        lagrangian_store.backward()
        self.primal_step()

        # The dual lagrangian store is empty for unconstrained problems
        dual_lagrangian_store = LagrangianStore()

        return RollOut(cmp_state.loss, cmp_state, lagrangian_store, dual_lagrangian_store)