# Source code for cooper.optim.torch_optimizers.nupi_optimizer

# Copyright (C) 2025 The Cooper Developers.
# Licensed under the MIT License.

"""The nuPI optimizer is a first-order optimization algorithm proposed in the ICML 2024
paper *On PI controllers for updating Lagrange multipliers in constrained optimization*
by Motahareh Sohrabi, Juan Ramirez, Tianyue H. Zhang, Simon Lacoste-Julien, and
Jose Gallego-Posada.
"""

import warnings
from collections.abc import Callable, Iterable
from enum import Enum
from typing import Optional, Union

import torch


class nuPIInitType(Enum):
    r"""Initialization schemes for the nuPI controller.

    The scheme selects how the smoothed-error state of the controller is seeded,
    which only affects the very first parameter update:

    * ``ZEROS``: the smoothed error starts at zero, so the first step behaves
      like SGD with learning rate :math:`\eta \times (K_I + (1 - \nu) K_P)`,
      i.e. :math:`\eta \times (K_P + K_I)` when no smoothing is used.
    * ``SGD``: the proportional term is suppressed on the first step, so the
      first step of ``nuPI(KP, KI)`` is equivalent to SGD with learning rate
      :math:`\eta \times K_I`.
    """

    ZEROS = 0
    SGD = 1
class nuPI(torch.optim.Optimizer):
    """PyTorch optimizer implementing the nuPI controller (see ``__init__``)."""

    def __init__(
        self,
        params: Iterable[torch.Tensor],
        lr: float,
        weight_decay: float = 0.0,
        Kp: Union[float, torch.Tensor] = 0.0,
        Ki: Union[float, torch.Tensor] = 1.0,
        ema_nu: float = 0.0,
        init_type: nuPIInitType = nuPIInitType.SGD,
        maximize: bool = False,
    ) -> None:
        r"""Implements the ``nuPI`` controller as a PyTorch optimizer.

        Controllers are designed to guide a system toward a desired state by
        adjusting a control variable. This is achieved by measuring the error,
        which is the difference between the desired and current states, and
        using this error to modify the control variable, thereby influencing the
        system.

        For this controller, the error signal is derived from the gradient of a
        loss function :math:`L` being optimized with respect to a parameter
        :math:`\vtheta`. Here, :math:`\vtheta` acts as the control variable,
        while the **gradient** of :math:`L` serves as the error signal, defined
        as :math:`\ve_t = \nabla L_t(\vtheta_t)`. The control objective of
        setting :math:`\nabla L_t(\vtheta_t) = 0` corresponds to finding a
        stationary point of the loss function, thereby minimizing (or
        maximizing) it.

        .. note::
            When applied to the Lagrange multipliers of a constrained
            minimization problem, the control state
            :math:`\nabla L_t(\vtheta_t)` corresponds to the gradient of the
            Lagrangian function with respect to the multipliers (e.g.,
            :math:`\nabla_{\vlambda} \Lag(\vx, \vlambda) = \vg(\vx)` for
            inequality-constrained problems). Setting this gradient to (less
            than or equal to) zero corresponds to finding a point that satisfies
            the constraints.

        The ``nuPI`` controller updates parameters as follows:

        .. math::
            \vxi_t &= \nu \vxi_{t-1} + (1 - \nu) \ve_t, \\
            \vtheta_1 &= \vtheta_0 - \eta (K_P \vxi_0 + K_I \ve_0), \\
            \vtheta_{t+1} &= \vtheta_t - \eta (K_I \ve_t + K_P (\vxi_t - \vxi_{t-1}))

        Here, :math:`\vxi_t` is a smoothed version of the error signal
        (:math:`\ve_t`), using an exponential moving average (EMA) with
        coefficient :math:`\nu`. :math:`K_P` and :math:`K_I` are the
        proportional and integral gains, respectively, while the learning rate
        :math:`\eta` is kept separate to allow comparison with other optimizers.

        Weight decay is decoupled from the error signal, following a similar
        approach to PyTorch's AdamW optimizer: the decay term is added to the
        update after the proportional and integral terms have been computed and
        is never accumulated into :math:`\vxi_t`.

        The formulas above describe minimization. When ``maximize=True``, the
        sign of the proportional/integral part of the update is flipped so that
        the parameters ascend the objective; weight decay still shrinks the
        parameters.

        **Initialization Schemes**:

        The initialization of the ``nuPI`` controller requires specifying the
        initial smoothed error signal, :math:`\vxi_{-1}`, which impacts the
        first parameter update. Two initialization schemes are available:

        - ``nuPIInitType.ZEROS``: Initializes :math:`\vxi_{-1} = \vzero`, so
          that :math:`\vxi_0 = (1 - \nu) \ve_0`. The first update rule becomes:

          .. math::
              \vtheta_1 = \vtheta_0 - \eta (K_P \vxi_0 + K_I \ve_0)
              = \vtheta_0 - \eta ((1 - \nu) K_P + K_I) \ve_0.

        - ``nuPIInitType.SGD``: The proportional term is suppressed on the first
          step (equivalently, :math:`\vxi_{-1}` is chosen so that
          :math:`\vxi_0 - \vxi_{-1} = \vzero`), producing a first step identical
          to SGD:

          .. math::
              \vtheta_1 = \vtheta_0 - \eta (K_P (\vxi_0 - \vxi_{-1}) + K_I \ve_0)
              = \vtheta_0 - \eta K_I \ve_0.

        .. note::
            nuPI(:math:`\eta`, :math:`K_P=0`, :math:`K_I=1`, :math:`\nu=0`)
            corresponds to SGD with learning rate :math:`\eta`.

            nuPI(:math:`\eta`, :math:`K_P=1`, :math:`K_I=1`, :math:`\nu=0`)
            corresponds to the optimistic gradient method
            :cite:p:`popov1980modification`.

        Args:
            params: iterable of parameters to optimize, or dicts defining
                parameter groups.
            lr: learning rate.
            weight_decay: weight decay (L2 penalty). Defaults to 0.
            Kp: proportional gain, given as a number or a tensor; numbers are
                converted to tensors. Defaults to 0.
            Ki: integral gain, given as a number or a tensor; numbers are
                converted to tensors. Defaults to 1.
            ema_nu: EMA coefficient for the smoothed error signal. Defaults to
                0, meaning no smoothing is applied.
            init_type: initialization scheme for :math:`\vxi_{-1}`. Defaults to
                ``nuPIInitType.SGD``, which matches the first step of SGD.
            maximize: whether to maximize the objective with respect to the
                parameters instead of minimizing. Defaults to ``False``.

        Raises:
            ValueError: If the learning rate, or weight decay is negative.
            ValueError: If the EMA coefficient is not in the range
                :math:`(-1, 1)`.
            ValueError: If the initialization type is invalid.
            NotImplementedError: If multiple parameter groups are used with
                non-scalar proportional or integral gains.

        Warnings:
            If a negative proportional or integral gain is used.
            If both proportional and integral gains are zero.
            If the EMA coefficient is negative.
        """
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not -1 < ema_nu < 1.0:
            raise ValueError(f"Invalid nu value: {ema_nu}")
        if init_type not in {nuPIInitType.ZEROS, nuPIInitType.SGD}:
            raise ValueError(f"Invalid init_type: {init_type}")

        # The update functions call tensor methods on the gains, so plain
        # numbers are promoted to 0-dim tensors here.
        if not isinstance(Kp, torch.Tensor):
            Kp = torch.tensor(Kp)
        if not isinstance(Ki, torch.Tensor):
            Ki = torch.tensor(Ki)

        if torch.any(Kp < 0.0):
            warnings.warn(f"Using a negative Kp coefficient: {Kp}", stacklevel=2)
        if torch.any(Ki < 0.0):
            warnings.warn(f"Using a negative Ki coefficient: {Ki}", stacklevel=2)
        if torch.all(Kp == 0.0) and torch.all(Ki == 0.0):
            warnings.warn("All PI coefficients are zero", stacklevel=2)
        if ema_nu < 0:
            warnings.warn("nuPI optimizer instantiated with negative EMA coefficient", stacklevel=2)

        defaults = {
            "lr": lr,
            "weight_decay": weight_decay,
            "Kp": Kp,
            "Ki": Ki,
            "ema_nu": ema_nu,
            "maximize": maximize,
            "init_type": init_type,
        }
        super().__init__(params, defaults)

        # A gain is "scalar" when it holds a single element. Comparing against
        # ``torch.Size([1])`` would wrongly reject 0-dim tensors such as the
        # ones created from the default float gains.
        gains_are_scalar = Kp.numel() == 1 and Ki.numel() == 1
        if len(self.param_groups) > 1 and not gains_are_scalar:
            raise NotImplementedError("When using multiple parameter groups, Kp and Ki must be scalars")

    @staticmethod
    def disambiguate_update_function(is_grad_sparse: bool, init_type: nuPIInitType) -> Callable:
        """Select the update routine for a gradient layout and init scheme.

        Args:
            is_grad_sparse: whether the parameter's gradient is a sparse tensor.
            init_type: initialization scheme of the controller.

        Returns:
            The module-level function implementing the matching nuPI update.
            Any ``init_type`` other than ``nuPIInitType.ZEROS`` selects the
            SGD-initialized variant.
        """
        if is_grad_sparse:
            if init_type == nuPIInitType.ZEROS:
                return _sparse_nupi_zero_init
            return _sparse_nupi_sgd_init

        if init_type == nuPIInitType.ZEROS:
            return _nupi_zero_init
        return _nupi_sgd_init

    @torch.no_grad()
    def step(self, closure: Optional[Callable] = None) -> Optional[float]:
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            # The whole step runs under ``no_grad``; gradients must be
            # re-enabled for the closure to be able to backpropagate.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                update_function = self.disambiguate_update_function(p.grad.is_sparse, group["init_type"])
                update_function(
                    param=p,
                    state=self.state[p],
                    lr=group["lr"],
                    weight_decay=group["weight_decay"],
                    Kp=group["Kp"],
                    Ki=group["Ki"],
                    ema_nu=group["ema_nu"],
                    maximize=group["maximize"],
                )

        return loss

    def load_state_dict(self, state_dict: dict) -> None:
        """Load the optimizer state and restore the dtype of boolean masks.

        Args:
            state_dict: optimizer state, as accepted by
                ``torch.optim.Optimizer.load_state_dict``.
        """
        super().load_state_dict(state_dict)

        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                if "needs_error_initialization_mask" in state:
                    # Need to convert to bool explicitly since torch might have loaded
                    # it as a float tensor and the `torch.where` calls would fail
                    state["needs_error_initialization_mask"] = state["needs_error_initialization_mask"].bool()
def _entries_at(dense: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Return the entries of ``dense`` addressed by COO-style ``indices``.

    ``indices`` has one row per sparse dimension, as returned by
    ``Tensor.indices()``. Unbinding the rows indexes the leading dimensions of
    ``dense`` jointly, so the result is laid out like the ``values()`` of the
    sparse tensor the indices came from.
    """
    return dense[indices.unbind(0)]


def _gains_at(gain: torch.Tensor, param: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Return the gain(s) that apply to the entries of ``param`` at ``indices``.

    Single-element gains are returned untouched and rely on broadcasting.
    Other gains are expanded to the shape of ``param`` (the same broadcasting
    the dense update relies on) and gathered at the observed entries.
    """
    if gain.numel() == 1:
        return gain
    return _entries_at(gain.expand_as(param), indices)


def _nupi_zero_init(
    param: torch.Tensor,
    state: dict,
    lr: float,
    weight_decay: float,
    Kp: torch.Tensor,
    Ki: torch.Tensor,
    ema_nu: float,
    maximize: bool,
) -> None:
    """Applies a nuPI step update to `param` using the zero initialization scheme.

    The smoothed error ``state["xi"]`` starts at zero and is only allocated when
    the proportional term is active.

    Args:
        param: parameter to update in place; its dense ``grad`` is the error.
        state: per-parameter optimizer state (may gain the key ``"xi"``).
        lr: learning rate.
        weight_decay: decoupled weight decay coefficient.
        Kp: proportional gain, broadcastable against ``param``.
        Ki: integral gain, broadcastable against ``param``.
        ema_nu: EMA coefficient of the smoothed error.
        maximize: ascend instead of descend.
    """
    detached_error = param.grad.clone().detach()

    # K_P (xi_t - xi_{t-1}) = K_P (1 - nu) (e_t - xi_{t-1}), so the update is
    # (K_I + K_P (1 - nu)) e_t - K_P (1 - nu) xi_{t-1}.
    xit_m1_coef = Kp * (1 - ema_nu)
    et_coef = Ki + Kp * (1 - ema_nu)
    uses_xi = bool(xit_m1_coef.ne(0).any())

    if uses_xi and "xi" not in state:
        state["xi"] = torch.zeros_like(param)

    nupi_update = torch.zeros_like(param)
    if et_coef.ne(0).any():
        nupi_update.add_(detached_error.mul(et_coef))
    if uses_xi:
        nupi_update.sub_(state["xi"].mul(xit_m1_coef))

    # Weight decay is applied after estimating the error change, similar to AdamW.
    # See https://arxiv.org/abs/1711.05101 for details.
    if weight_decay != 0:
        nupi_update.add_(param, alpha=-weight_decay if maximize else weight_decay)

    alpha = lr if maximize else -lr
    param.add_(nupi_update, alpha=alpha)

    if uses_xi:
        state["xi"].mul_(ema_nu).add_(detached_error, alpha=1 - ema_nu)


def _sparse_nupi_zero_init(
    param: torch.Tensor,
    state: dict,
    lr: float,
    weight_decay: float,
    Kp: torch.Tensor,
    Ki: torch.Tensor,
    ema_nu: float,
    maximize: bool,
) -> None:
    """Analogous to `_nupi_zero_init` but with support for sparse gradients.

    This function implements updates based on a zero initialization scheme.
    Only the entries present in the (coalesced) sparse gradient are updated,
    including for weight decay and for the smoothed error ``state["xi"]``.

    Args:
        param: parameter to update in place; its sparse ``grad`` is the error.
        state: per-parameter optimizer state (may gain the key ``"xi"``).
        lr: learning rate.
        weight_decay: decoupled weight decay coefficient.
        Kp: proportional gain; single-element or broadcastable to ``param``.
        Ki: integral gain; single-element or broadcastable to ``param``.
        ema_nu: EMA coefficient of the smoothed error.
        maximize: ascend instead of descend.
    """
    error = param.grad.coalesce()  # the update is non-linear so indices must be unique
    error_indices = error.indices()
    detached_error_values = error.values().clone().detach()

    if detached_error_values.numel() == 0:
        # Skip update for empty grad
        return

    # Gather the gains ONCE, at the observed entries; the derived coefficients
    # are then already aligned with ``detached_error_values``.
    Ki_values = _gains_at(Ki, param, error_indices)
    Kp_values = _gains_at(Kp, param, error_indices)

    et_coef_values = Ki_values + Kp_values * (1 - ema_nu)
    xit_m1_coef_values = Kp_values * (1 - ema_nu)
    uses_xi = bool(xit_m1_coef_values.ne(0).any())

    if uses_xi and "xi" not in state:
        state["xi"] = torch.zeros_like(param)

    nupi_update_values = torch.zeros_like(detached_error_values)
    if et_coef_values.ne(0).any():
        nupi_update_values.add_(detached_error_values.mul(et_coef_values))
    if uses_xi:
        xi_values = _entries_at(state["xi"], error_indices)
        nupi_update_values.sub_(xi_values.mul(xit_m1_coef_values))

    # Weight decay is applied after estimating the delta and curvature, similar to
    # AdamW. See https://arxiv.org/abs/1711.05101 for details. Only the observed
    # entries of the parameter are decayed.
    if weight_decay != 0:
        observed_params = _entries_at(param, error_indices)
        nupi_update_values.add_(observed_params, alpha=-weight_decay if maximize else weight_decay)

    nupi_update = torch.sparse_coo_tensor(error_indices, nupi_update_values, size=param.shape)

    alpha = lr if maximize else -lr
    param.add_(nupi_update, alpha=alpha)

    if uses_xi:
        new_xi_values = xi_values.mul(ema_nu).add(detached_error_values, alpha=1 - ema_nu)
        state["xi"][error_indices.unbind(0)] = new_xi_values


def _nupi_sgd_init(
    param: torch.Tensor,
    state: dict,
    lr: float,
    weight_decay: float,
    Kp: torch.Tensor,
    Ki: torch.Tensor,
    ema_nu: float,
    maximize: bool,
) -> None:
    """Applies a nuPI step update to `param` using the "SGD" initialization scheme.

    On the first call for a parameter the proportional term is skipped, so the
    step matches gradient descent with gain ``Ki``.

    Args:
        param: parameter to update in place; its dense ``grad`` is the error.
        state: per-parameter optimizer state (may gain the key ``"xi"``).
        lr: learning rate.
        weight_decay: decoupled weight decay coefficient.
        Kp: proportional gain, broadcastable against ``param``.
        Ki: integral gain, broadcastable against ``param``.
        ema_nu: EMA coefficient of the smoothed error.
        maximize: ascend instead of descend.
    """
    detached_error = param.grad.clone().detach()

    uses_ki_term = bool(Ki.ne(0).any())
    uses_kp_term = bool((Kp * (1 - ema_nu)).ne(0).any())

    nupi_update = torch.zeros_like(param)

    if uses_ki_term:
        nupi_update.add_(detached_error.mul(Ki))

    # First step is designed to match GD, so the proportional term only
    # contributes once ``xi`` exists.
    if uses_kp_term and "xi" in state:
        kp_term_contribution = (1 - ema_nu) * (detached_error - state["xi"])
        nupi_update.add_(kp_term_contribution.mul(Kp))

    # Weight decay is applied after estimating the error change, similar to AdamW.
    # See https://arxiv.org/abs/1711.05101 for details.
    if weight_decay != 0:
        nupi_update.add_(param, alpha=-weight_decay if maximize else weight_decay)

    alpha = lr if maximize else -lr
    param.add_(nupi_update, alpha=alpha)

    if uses_kp_term:
        if "xi" not in state:
            # Initialize xi_0 = 0
            # NOTE(review): `_sparse_nupi_sgd_init` instead folds the first
            # error into xi on a coordinate's first step; confirm which of the
            # two seedings is intended.
            state["xi"] = torch.zeros_like(param)
        else:
            state["xi"].mul_(ema_nu).add_(detached_error, alpha=1 - ema_nu)


def _sparse_nupi_sgd_init(
    param: torch.Tensor,
    state: dict,
    lr: float,
    weight_decay: float,
    Kp: torch.Tensor,
    Ki: torch.Tensor,
    ema_nu: float,
    maximize: bool,
) -> None:
    """Analogous to `_nupi_sgd_init` but with support for sparse gradients.

    This function implements updates based on a "SGD" initialization scheme that
    makes the first step of nuPI (on each coordinate) match that of SGD. A
    boolean mask in ``state["needs_error_initialization_mask"]`` records which
    coordinates have not been observed yet.

    Args:
        param: parameter to update in place; its sparse ``grad`` is the error.
        state: per-parameter optimizer state (may gain the keys ``"xi"`` and
            ``"needs_error_initialization_mask"``).
        lr: learning rate.
        weight_decay: decoupled weight decay coefficient.
        Kp: proportional gain; single-element or broadcastable to ``param``.
        Ki: integral gain; single-element or broadcastable to ``param``.
        ema_nu: EMA coefficient of the smoothed error.
        maximize: ascend instead of descend.
    """
    error = param.grad.coalesce()  # the update is non-linear so indices must be unique
    error_indices = error.indices()
    detached_error_values = error.values().clone().detach()

    if detached_error_values.numel() == 0:
        # Skip update for empty grad
        return

    observed = error_indices.unbind(0)

    filtered_Ki_values = _gains_at(Ki, param, error_indices)
    filtered_Kp_values = _gains_at(Kp, param, error_indices)

    uses_ki_term = bool(filtered_Ki_values.ne(0).any())
    uses_kp_term = bool((filtered_Kp_values * (1 - ema_nu)).ne(0).any())

    if uses_kp_term and "xi" not in state:
        state["xi"] = torch.zeros_like(param)
        state["needs_error_initialization_mask"] = torch.ones_like(param, dtype=torch.bool)

    nupi_update_values = torch.zeros_like(detached_error_values)

    if uses_ki_term:
        nupi_update_values.add_(detached_error_values.mul(filtered_Ki_values))

    if uses_kp_term:
        previous_xi_values = state["xi"][observed]
        needs_initialization = state["needs_error_initialization_mask"][observed]
        proportional_term_contribution = torch.where(
            needs_initialization,
            # If state has not been initialized, xi_0 = 0
            torch.zeros_like(detached_error_values),
            # Else, we use recursive update
            (1 - ema_nu) * (detached_error_values - previous_xi_values),
        )
        nupi_update_values.add_(proportional_term_contribution.mul(filtered_Kp_values))

    # Weight decay is applied after estimating the delta and curvature, similar to
    # AdamW. See https://arxiv.org/abs/1711.05101 for details. Only the observed
    # entries of the parameter are decayed.
    if weight_decay != 0:
        observed_params = param[observed]
        nupi_update_values.add_(observed_params, alpha=-weight_decay if maximize else weight_decay)

    nupi_update = torch.sparse_coo_tensor(error_indices, nupi_update_values, size=param.shape)

    alpha = lr if maximize else -lr
    param.add_(nupi_update, alpha=alpha)

    if uses_kp_term:
        state["xi"][observed] = previous_xi_values.mul(ema_nu).add(detached_error_values, alpha=1 - ema_nu)
        state["needs_error_initialization_mask"][observed] = False