# coding: utf8
"""
Implementation of :py:class:`ConstrainedOptimizer` class, which has 2 main
methods:
- :py:meth:`~ConstrainedOptimizer.zero_grad`
- :py:meth:`~ConstrainedOptimizer.step`
"""
from typing import Callable, Optional
import torch
from .problem import CMPState, Formulation
[docs]class ConstrainedOptimizer:
"""
Optimizes a :py:class:`~cooper.problem.ConstrainedMinimizationProblem`
given its :py:class:`~cooper.problem.Formulation`.
A ``ConstrainedOptimizer`` includes one or two
:class:`torch.optim.Optimizer`\\s, for the primal and dual variables
associated with the ``Formulation``, respectively.
A ``ConstrainedOptimizer`` can be used on constrained or unconstrained
``ConstrainedMinimizationProblem``\\s. Please refer to the documentation
of the :py:class:`~cooper.problem.ConstrainedMinimizationProblem` and
:py:class:`~cooper.problem.Formulation` classes for further details on
handling unconstrained problems.
Args:
formulation: ``Formulation`` of the ``ConstrainedMinimizationProblem``
to be optimized.
primal_optimizer: Fully instantiated ``torch.optim.Optimizer`` used
to optimize the primal parameters (e.g. model parameters).
dual_optimizer: Partially instantiated ``torch.optim.Optimizer``
used to optimize the dual variables (e.g. Lagrange multipliers).
Defaults to None.
When dealing with an unconstrained problem, should be set to None.
dual_scheduler: Partially instantiated
``torch.optim.lr_scheduler._LRScheduler``
used to schedule the learning rate of the dual variables.
Defaults to None.
When dealing with an unconstrained problem, should be set to None.
alternating: Whether to alternate parameter updates between primal and
dual parameters. Otherwise, do simultaneous parameter updates.
Defaults to False.
dual_restarts: If True, perform "restarts" on the Lagrange
multipliers associated with inequality constraints: whenever the
constraint is satisfied, directly set the multiplier to zero.
Defaults to False.
"""
def __init__(
self,
formulation: Formulation,
primal_optimizer: torch.optim.Optimizer,
dual_optimizer: Optional[torch.optim.Optimizer] = None,
dual_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
alternating: bool = False,
dual_restarts: bool = False,
):
self.formulation = formulation
self.cmp = self.formulation.cmp
self.primal_optimizer = primal_optimizer
self.dual_optimizer = dual_optimizer
self.dual_scheduler = dual_scheduler
self.alternating = alternating
self.dual_restarts = dual_restarts
self.sanity_checks()
[docs] def sanity_checks(self):
"""
Perform sanity checks on the initialization of ``ConstrainedOptimizer``.
Raises:
NotImplementedError: The ``Formulation`` has an augmented Lagrangian
coefficient and ``primal_optimizer`` has an ``extrapolation``
function. This is not supported because of possible unexpected
behavior.
RuntimeError: The ``primal_optimizer`` has an ``extrapolation``
function and ``alternating`` was set to True. Mixing
extrapolation and alternating updates is not supported.
RuntimeError: a ``dual_optimizer`` was provided but the
``ConstrainedMinimizationProblem`` of formulation was
unconstrained. There are no dual variables to optimize.
RuntimeError: a ``dual_scheduler`` was provided but the
``ConstrainedMinimizationProblem`` of formulation was
unconstrained. There are no dual variables and no
``dual_optimizer`` for learning rate scheduling.
RuntimeError: a ``dual_scheduler`` was provided but no
``dual_optimizer`` was provided. Can not schedule the learning
rate of an unknown optimizer.
RuntimeError: the considered ``ConstrainedMinimizationProblem`` is
unconstrained, but the provided ``primal_optimizer`` has an
``extrapolation`` function. This is not supported because of
unexpected behavior when using extrapolation to update the
primal parameters without any dual parameters.
RuntimeError: One of ``primal_optimizer`` or ``dual_optimizer`` has
an extrapolation function while the other does not.
Extrapolation on only one player is not supported.
"""
is_alternating = self.alternating
is_aug_lag = hasattr(self.formulation, "aug_lag_coefficient") and (
self.formulation.aug_lag_coefficient > 0
)
# We assume that both optimizers agree on whether to use extrapolation
# or not, so we use the primal optimizer as reference for deciding
# whether to use extrapolation. See check below for matching
# extrapolation behavior.
self.is_extrapolation = hasattr(self.primal_optimizer, "extrapolation")
if is_aug_lag and self.is_extrapolation:
raise NotImplementedError(
"""It is currently not possible to use extrapolation and an
augmented Lagrangian formulation"""
)
if is_alternating and self.is_extrapolation:
raise RuntimeError(
"""Should not use extrapolation and alternating updates
simultaneously. Please disable one of these two modes."""
)
if not (self.cmp.is_constrained) and (self.dual_optimizer is not None):
raise RuntimeError(
"""Provided a dual optimizer, but the `Problem` class claims to
be unconstrained."""
)
if self.dual_scheduler is not None:
if not (self.cmp.is_constrained):
raise RuntimeError(
"""A dual scheduler was provided, but the `Problem` class
claims to be unconstrained."""
)
if self.dual_optimizer is None:
raise RuntimeError(
"""A dual scheduler was provided, but no dual optimizer
was provided."""
)
if not (self.cmp.is_constrained) and self.is_extrapolation:
raise RuntimeError(
"""Using an extrapolating optimizer an unconstrained problem
might result in unexpected behavior. Consider using a
non-extrapolating optimizer instead."""
)
if hasattr(self.primal_optimizer, "extrapolation") != hasattr(
self.dual_optimizer, "extrapolation"
):
raise RuntimeError(
"""Primal and dual optimizers do not agree on whether to use
extrapolation or not."""
)
[docs] def step(
self,
closure: Optional[Callable[..., CMPState]] = None,
*closure_args,
**closure_kwargs
):
"""
Performs a single optimization step on both the primal and dual
variables. If ``dual_scheduler`` is provided, a scheduler step is
performed on the learning rate of the ``dual_optimizer``.
Args:
closure: Closure ``Callable`` required for re-evaluating the
objective and constraints when performing alternating or
extrapolating updates.
Defaults to None.
*closure_args: Arguments to be passed to the closure function
when re-evaluating.
**closure_kwargs: Keyword arguments to be passed to the closure
function when re-evaluating.
"""
if self.cmp.is_constrained and not hasattr(self.dual_optimizer, "param_groups"):
assert self.dual_optimizer is not None and callable(self.dual_optimizer)
# Checks if needed and instantiates dual_optimizer
self.dual_optimizer = self.dual_optimizer(self.formulation.dual_parameters)
if self.dual_scheduler is not None:
assert callable(self.dual_scheduler), "dual_scheduler must be callable"
# Instantiates the dual_scheduler
self.dual_scheduler = self.dual_scheduler(self.dual_optimizer)
if self.is_extrapolation or self.alternating:
assert closure is not None
if self.is_extrapolation:
# Store parameter copy and compute t+1/2 iterates
self.primal_optimizer.extrapolation() # type: ignore
if self.cmp.is_constrained:
# Call to dual_step flips sign of gradients, then triggers call
# to dual_optimizer.extrapolation and projects dual variables
self.dual_step(call_extrapolation=True)
# Zero gradients and recompute loss at t+1/2
self.zero_grad()
# For extrapolation, we need closure args here as the parameter
# values will have changed in the update applied on the
# extrapolation step
lagrangian = self.formulation.composite_objective(
closure, *closure_args, **closure_kwargs
) # type: ignore
# Populate gradients at extrapolation point
self.formulation.custom_backward(lagrangian)
# After this, the calls to `step` will update the stored copies with
# the newly computed gradients
self.primal_optimizer.step()
if self.cmp.is_constrained:
self.dual_step()
if self.dual_scheduler is not None:
# Do a step on the dual scheduler after the actual step on
# the dual parameters. Intermediate updates that take
# place inside the extrapolation process do not perform a
# call to the scheduler's step method
self.dual_scheduler.step()
else:
self.primal_optimizer.step()
if self.cmp.is_constrained:
if self.alternating:
# TODO: add test for this
# Once having updated primal parameters, re-compute gradient
# Skip gradient wrt model parameters to avoid wasteful
# computation, as we only need gradient wrt multipliers.
with torch.no_grad():
assert closure is not None
self.cmp.state = closure(*closure_args, **closure_kwargs)
lagrangian = self.formulation.composite_objective(self.cmp) # type: ignore
# Zero-out gradients for dual variables since they were
# already populated earlier.
# We also zero-out primal gradients for safety although not
# really necessary.
self.zero_grad(ignore_primal=False, ignore_dual=False)
# Not passing lagrangian since we only want to update the
# gradients for the dual variables
self.formulation._populate_gradients(
lagrangian=None, ignore_primal=True
)
self.dual_step()
if self.dual_scheduler is not None:
self.dual_scheduler.step()
def dual_step(self, call_extrapolation=False):
# Flip gradients for multipliers to perform ascent.
# We only do the flipping *right before* applying the optimizer step to
# avoid accidental double sign flips.
for multiplier in self.formulation.state():
if multiplier is not None:
multiplier.grad.mul_(-1.0)
# Update multipliers based on current constraint violations (gradients)
if call_extrapolation:
self.dual_optimizer.extrapolation()
else:
self.dual_optimizer.step()
if self.formulation.ineq_multipliers is not None:
if self.dual_restarts:
# "Reset" value of inequality multipliers to zero as soon as
# solution becomes feasible
self.restart_dual_variables()
# Apply projection step to inequality multipliers
self.formulation.ineq_multipliers.project_()
def restart_dual_variables(self):
# Call to formulation._populate_gradients has already flipped sign
# A currently *positive* gradient means original defect is negative, so
# the constraint is being satisfied.
# The code below still works in the case of proxy constraints, since the
# multiplier updates are computed based on *non-proxy* constraints
feasible_filter = self.formulation.ineq_multipliers.weight.grad > 0
self.formulation.ineq_multipliers.weight.grad[feasible_filter] = 0.0
self.formulation.ineq_multipliers.weight.data[feasible_filter] = 0.0
[docs] def zero_grad(self, ignore_primal: bool = False, ignore_dual: bool = False):
"""
Sets the gradients of all optimized
:py:class:`~torch.nn.parameter.Parameter`\\s to zero. This includes both
the primal and dual variables.
Args:
ignore_primal: If True, the gradients of the primal variables will
not be zeroed. Defaults to False.
ignore_dual: If True, the gradients of the dual variables will not
be zeroed. Defaults to False.
"""
if not ignore_primal:
self.primal_optimizer.zero_grad()
if not ignore_dual:
if self.formulation.is_state_created:
if self.dual_optimizer is None:
raise RuntimeError(
"Requested zeroing gradients but dual_optimizer is None."
)
else:
self.dual_optimizer.zero_grad()