Source code for cooper.optim

"""Extrapolation Optimizers and functions for partial instantiation of dual
optimizers and schedulers"""

import functools
import math
from collections.abc import Iterable
from typing import Callable, List, Tuple, Type, no_type_check

import torch
from torch.optim.lr_scheduler import _LRScheduler


@no_type_check
def partial_optimizer(optim_cls: Type[torch.optim.Optimizer], **optim_kwargs):
    """
    Build an optimizer subclass with some constructor keywords pre-filled.

    Unlike :py:func:`functools.partial`, the result is a genuine subclass of
    ``optim_cls``: its attributes can be inspected and it is instantiated like
    any other optimizer class, supplying only the remaining arguments.

    Args:
        optim_cls: Pytorch optimizer class to be partially instantiated.
        **optim_kwargs: Keyword arguments for optimizer hyperparameters.

    Returns:
        A subclass of ``optim_cls`` named ``PartialOptimizer`` whose
        ``__init__`` has ``optim_kwargs`` already bound.
    """

    # partialmethod keeps ``self`` as the first positional argument, so the
    # pre-bound constructor still behaves as a regular method of the subclass.
    prefilled_init = functools.partialmethod(optim_cls.__init__, **optim_kwargs)

    class PartialOptimizer(optim_cls):
        __init__ = prefilled_init

    return PartialOptimizer


@no_type_check
def partial_scheduler(scheduler_cls: Type[_LRScheduler], **scheduler_kwargs):
    """
    Build a learning rate scheduler subclass with some constructor keywords
    pre-filled.

    Unlike :py:func:`functools.partial`, the result is a genuine subclass of
    ``scheduler_cls``: its attributes can be inspected and it is instantiated
    like any other scheduler class, supplying only the remaining arguments
    (typically the optimizer).

    Args:
        scheduler_cls: Pytorch scheduler class to be partially instantiated.
        **scheduler_kwargs: Keyword arguments for scheduler hyperparameters.

    Returns:
        A subclass of ``scheduler_cls`` named ``PartialScheduler`` whose
        ``__init__`` has ``scheduler_kwargs`` already bound.
    """

    # partialmethod keeps ``self`` as the first positional argument, so the
    # pre-bound constructor still behaves as a regular method of the subclass.
    prefilled_init = functools.partialmethod(
        scheduler_cls.__init__, **scheduler_kwargs
    )

    class PartialScheduler(scheduler_cls):
        __init__ = prefilled_init

    return PartialScheduler


# -----------------------------------------------------------------------------
# Implementation of ExtraOptimizers contains minor edits on source code from:
# https://github.com/GauthierGidel/Variational-Inequality-GAN/blob/master/optim/extragradient.py
# -----------------------------------------------------------------------------

#  MIT License

# Copyright (c) Facebook, Inc. and its affiliates.

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# written by Hugo Berard (berard.hugo@gmail.com) while at Facebook.


class ExtragradientOptimizer(torch.optim.Optimizer):
    """Base class for optimizers with extrapolation step.

    Subclasses implement :py:meth:`update`, which returns the additive change
    to apply to a single parameter, or ``None`` to leave it untouched.

    Args:
        params: an iterable of :class:`torch.Tensor`\\s or
            :class:`dict`\\s. Specifies what Tensors should be optimized.
        defaults: a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    def __init__(self, params: Iterable, defaults: dict):
        super().__init__(params, defaults)
        # Snapshot of the parameter values taken at the first extrapolation
        # since the last call to ``step``; empty when no snapshot is pending.
        self.params_copy: List[torch.Tensor] = []

    def update(self, p, group):
        """Return the additive update for parameter ``p`` of ``group``.

        Must be overridden by subclasses. Returning ``None`` signals that the
        parameter should be skipped.
        """
        raise NotImplementedError

    def extrapolation(self):
        """
        Performs the extrapolation step and saves a copy of the current
        parameters for the update step.
        """
        # Decide once, up front, whether this call has to take the snapshot.
        needs_snapshot = not self.params_copy

        for group in self.param_groups:
            for p in group["params"]:
                delta = self.update(p, group)

                if needs_snapshot:
                    # Several extrapolation steps may precede one update, but
                    # only the values before the first extrapolation are
                    # kept. Every parameter is recorded (even skipped ones)
                    # so that positions line up with the traversal in `step`.
                    self.params_copy.append(p.data.clone())

                if delta is not None:
                    # Move the live parameters to the extrapolated point.
                    p.data.add_(delta)

    def step(self, closure: Callable = None):
        """Performs a single optimization step.

        Args:
            closure: A closure that reevaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.

        Raises:
            RuntimeError: If :py:meth:`extrapolation` has not been called
                since the last step.
        """
        if not self.params_copy:
            raise RuntimeError("Need to call extrapolation before calling step.")

        loss = closure() if closure is not None else None

        # Walk the parameters in the same order used by `extrapolation` so
        # the running position indexes the matching saved copy.
        group_param_pairs = (
            (group, p) for group in self.param_groups for p in group["params"]
        )
        for position, (group, p) in enumerate(group_param_pairs):
            delta = self.update(p, group)
            if delta is not None:
                # Apply the update to the values saved before extrapolation.
                p.data = self.params_copy[position].add_(delta)

        # Drop the snapshot; the next extrapolation takes a fresh one.
        self.params_copy = []

        return loss
class ExtraSGD(ExtragradientOptimizer):
    """
    Implements stochastic gradient descent with extrapolation step (optionally
    with momentum).

    Nesterov momentum is based on the formula from
    :cite:t:`sutskever2013initialization`.

    Args:
        params: Iterable of parameters to optimize or dicts defining
            parameter groups.
        lr: Learning rate.
        momentum: Momentum factor.
        weight_decay: Weight decay (L2 penalty).
        dampening: Dampening for momentum.
        nesterov: If ``True``, enables Nesterov momentum.

    Raises:
        ValueError: If ``lr`` is ``None`` or negative, if ``momentum`` or
            ``weight_decay`` is negative, or if ``nesterov`` is requested
            without a positive momentum and zero dampening.

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        :cite:t:`sutskever2013initialization` and implementations in some
        other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            v = \\rho \\cdot v + g \\\\
            p = p - lr \\cdot v

        where :math:`p`, :math:`v`, :math:`g` and :math:`\\rho` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to :cite:t:`sutskever2013initialization` and other
        frameworks which employ an update of the form

        .. math::
            v &= \\rho \\cdot v + lr \\cdot g \\\\
            p &= p - v

        The Nesterov version is analogously modified.
    """

    def __init__(
        self,
        params: Iterable,
        lr: float,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov: bool = False,
    ):
        if lr is None or lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)

    def __setstate__(self, state):
        # This class does not derive from torch.optim.SGD, so
        # ``super(torch.optim.SGD, self)`` would raise TypeError; delegate to
        # the actual parent class instead.
        super().__setstate__(state)
        for group in self.param_groups:
            # Restored param groups may lack the "nesterov" key.
            group.setdefault("nesterov", False)

    def update(self, p, group):
        """Compute the SGD update for parameter ``p``.

        Args:
            p: Parameter whose update is computed.
            group: Parameter group of ``p``; supplies ``lr``, ``momentum``,
                ``dampening``, ``weight_decay`` and ``nesterov``.

        Returns:
            ``-lr * d_p`` where ``d_p`` is the (weight-decayed, momentum
            adjusted) gradient, or ``None`` if ``p`` has no gradient.
        """
        weight_decay = group["weight_decay"]
        momentum = group["momentum"]
        dampening = group["dampening"]
        nesterov = group["nesterov"]

        if p.grad is None:
            return None
        d_p = p.grad.data

        if weight_decay != 0:
            # Out-of-place on purpose: an in-place add would modify ``p.grad``
            # itself, compounding the decay term if ``update`` runs again on
            # the same gradient (e.g. repeated extrapolation calls).
            d_p = d_p.add(p.data, alpha=weight_decay)

        if momentum != 0:
            param_state = self.state[p]
            if "momentum_buffer" not in param_state:
                # First update for this parameter: the buffer starts as the
                # raw gradient (dampening is not applied on this step).
                buf = param_state["momentum_buffer"] = torch.zeros_like(p.data)
                buf.mul_(momentum).add_(d_p)
            else:
                buf = param_state["momentum_buffer"]
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        return -group["lr"] * d_p
class ExtraAdam(ExtragradientOptimizer):
    """Implements the Adam algorithm with an extrapolation step.

    Args:
        params: Iterable of parameters to optimize or dicts defining
            parameter groups.
        lr: Learning rate.
        betas: Coefficients used for computing running averages of gradient
            and its square.
        eps: Term added to the denominator to improve numerical stability.
        weight_decay: Weight decay (L2 penalty).
        amsgrad: Flag to use the AMSGrad variant of this algorithm from
            :cite:t:`reddi2018amsgrad`.

    Raises:
        ValueError: If ``lr`` or ``eps`` is negative, or if either entry of
            ``betas`` lies outside ``[0, 1)``.
    """

    def __init__(
        self,
        params: Iterable,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        amsgrad: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            # Restored param groups may lack the "amsgrad" key.
            group.setdefault("amsgrad", False)

    def update(self, p, group):
        """Compute the Adam update for parameter ``p``.

        Every call advances the per-parameter ``step`` counter and the running
        moment estimates stored in ``self.state[p]``.

        Args:
            p: Parameter whose update is computed.
            group: Parameter group of ``p``; supplies ``lr``, ``betas``,
                ``eps``, ``weight_decay`` and ``amsgrad``.

        Returns:
            The bias-corrected Adam step ``-step_size * exp_avg / denom``, or
            ``None`` if ``p`` has no gradient.

        Raises:
            RuntimeError: If the gradient of ``p`` is sparse.
        """
        if p.grad is None:
            return None
        grad = p.grad.data
        if grad.is_sparse:
            raise RuntimeError(
                "Adam does not support sparse gradients, please consider SparseAdam instead"
            )
        amsgrad = group["amsgrad"]

        state = self.state[p]

        # State initialization
        if len(state) == 0:
            state["step"] = 0
            # Exponential moving average of gradient values
            state["exp_avg"] = torch.zeros_like(p.data)
            # Exponential moving average of squared gradient values
            state["exp_avg_sq"] = torch.zeros_like(p.data)
            if amsgrad:
                # Maintains max of all exp. moving avg. of sq. grad. values
                state["max_exp_avg_sq"] = torch.zeros_like(p.data)

        exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
        if amsgrad:
            max_exp_avg_sq = state["max_exp_avg_sq"]
        beta1, beta2 = group["betas"]

        state["step"] += 1

        if group["weight_decay"] != 0:
            # Keyword ``alpha`` form; the positional (scalar, tensor) overload
            # of ``add`` is deprecated. Out-of-place, so ``p.grad`` is intact.
            grad = grad.add(p.data, alpha=group["weight_decay"])

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            # Use the max. for normalizing running avg. of gradient
            denom = max_exp_avg_sq.sqrt().add_(group["eps"])
        else:
            denom = exp_avg_sq.sqrt().add_(group["eps"])

        bias_correction1 = 1 - beta1 ** state["step"]
        bias_correction2 = 1 - beta2 ** state["step"]
        step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1

        return -step_size * exp_avg / denom