#!/usr/bin/env python
"""Classes for modeling dual variables (e.g. Lagrange multipliers)."""
import abc
import torch
class BaseMultiplier(torch.nn.Module, metaclass=abc.ABCMeta):
    """
    Base class for Lagrange multipliers. This base class can be extended to
    different types of multipliers: Dense, Sparse or implicit multipliers.

    Subclasses must implement the :py:attr:`shape` and :py:attr:`grad`
    properties as well as the :py:meth:`forward` and :py:meth:`project_`
    methods; the class cannot be instantiated until all four are provided.
    """

    def __init__(self) -> None:
        super().__init__()

    @property
    @abc.abstractmethod
    def shape(self):
        """
        Returns the shape of the explicit multipliers. In the case of implicit
        multipliers, this should be the shape of the *actual* predicted
        multipliers.
        """
        pass

    @property
    @abc.abstractmethod
    def grad(self):
        """
        Returns the gradient of trainable parameters associated with the
        multipliers. In the case of implicit multipliers, this corresponds to
        the gradient with respect to the parameters of the model which predicts
        the multiplier values.
        """
        pass

    @abc.abstractmethod
    def forward(self):
        """
        Returns the *actual* value of the multipliers. When using implicit
        multipliers, the signature of this method may be changed to enable
        passing the "features" of the constraint to predict the corresponding
        multiplier.
        """
        pass

    @abc.abstractmethod
    def project_(self):
        """
        In-place projection function for multipliers.
        """
        pass
class DenseMultiplier(BaseMultiplier):
    """
    A dense multiplier. Holds a :py:class:`~torch.nn.parameter.Parameter`,
    which contains the value of the Lagrange multipliers associated with the
    equality or inequality constraints of a
    :py:class:`~cooper.problem.ConstrainedMinimizationProblem`.

    Args:
        init: Initial value of the multiplier.
        positive: Whether to enforce non-negativity on the values of the
            multiplier.
    """

    def __init__(self, init: torch.Tensor, *, positive: bool = False):
        super().__init__()
        # Wrapping ``init`` in a Parameter registers it with the module, so it
        # shows up in ``parameters()`` and ``state_dict()``.
        self.weight = torch.nn.Parameter(init)
        self.positive = positive

    @property
    def shape(self) -> torch.Size:
        """Returns the shape of the multiplier tensor."""
        return self.weight.shape

    @property
    def grad(self):
        """
        Returns current gradient stored in the multiplier tensor, i.e. the
        ``grad`` attribute of the underlying parameter.
        """
        return self.weight.grad

    def forward(self) -> torch.Tensor:
        """Return the current value of the multiplier."""
        return self.weight

    def project_(self) -> None:
        """
        Ensures multipliers associated with inequality constraints remain
        non-negative.

        This is a no-op unless the multiplier was created with
        ``positive=True``; in that case negative entries are replaced by
        zero.
        """
        if self.positive:
            # Work on ``.data`` so the projection is applied to the raw values
            # of the parameter rather than being recorded by autograd.
            self.weight.data = torch.relu(self.weight.data)

    def __str__(self) -> str:
        return str(self.weight.data)

    def __repr__(self) -> str:
        pos_str = "inequality" if self.positive else "equality"
        # ``!s`` forces ``str()`` on the tensor; plain f-string formatting of
        # a 0-dim tensor would print the bare Python scalar instead.
        return f"DenseMultiplier({pos_str}, {self.weight.data!s})"