import abc
from dataclasses import dataclass
from typing import Optional
import torch
# Formulation, and some other classes below, are heavily inspired by the design
# of the TensorFlow Constrained Optimization (TFCO) library:
# https://github.com/google-research/tensorflow_constrained_optimization
@dataclass
class CMPState:
    """Represents the "state" of a Constrained Minimization Problem in terms of
    the value of its loss and constraint violations/defects.

    Args:
        loss: Value of the loss or main objective to be minimized :math:`f(x)`
        ineq_defect: Violation of the inequality constraints :math:`g(x)`
        eq_defect: Violation of the equality constraints :math:`h(x)`
        proxy_ineq_defect: Differentiable surrogate for the inequality
            constraints as proposed by :cite:t:`cotter2019JMLR`.
        proxy_eq_defect: Differentiable surrogate for the equality constraints
            as proposed by :cite:t:`cotter2019JMLR`.
        misc: Optional additional information to be stored along with the state
            of the CMP
    """

    # All fields default to ``None`` so a partially-evaluated state (e.g. only
    # the constraints, or only the loss) can be represented.
    loss: Optional[torch.Tensor] = None
    ineq_defect: Optional[torch.Tensor] = None
    eq_defect: Optional[torch.Tensor] = None
    proxy_ineq_defect: Optional[torch.Tensor] = None
    proxy_eq_defect: Optional[torch.Tensor] = None
    misc: Optional[dict] = None

    def as_tuple(self) -> tuple:
        """Return the state as a 6-tuple ``(loss, ineq_defect, eq_defect,
        proxy_ineq_defect, proxy_eq_defect, misc)``."""
        return (
            self.loss,
            self.ineq_defect,
            self.eq_defect,
            self.proxy_ineq_defect,
            self.proxy_eq_defect,
            self.misc,
        )
class ConstrainedMinimizationProblem(abc.ABC):
    """Base class for constrained minimization problems."""

    def __init__(self):
        # Start with an empty state; it is populated/overwritten by calls to
        # ``closure`` or ``defect_fn`` in concrete subclasses.
        self._state = CMPState()

    @property
    def state(self) -> CMPState:
        """Most recently stored :py:class:`CMPState` for this problem."""
        return self._state

    @state.setter
    def state(self, value: CMPState):
        self._state = value

    @abc.abstractmethod
    def closure(self, *args, **kwargs) -> CMPState:
        """
        Computes the state of the CMP based on the current value of the primal
        parameters.

        The signature of this abstract function may be changed to accommodate
        situations that require a model, (mini-batched) inputs/targets, or
        other arguments to be passed.

        Structuring the CMP class around this closure method, enables the re-use
        of shared sections of a computational graph. For example, consider a
        case where we want to minimize a model's cross entropy loss subject to
        a constraint on the entropy of its predictions. Both of these quantities
        depend on the predicted logits (on a minibatch). This closure-centric
        design allows flexible problem specifications while avoiding
        re-computation.
        """

    def defect_fn(self) -> CMPState:
        """
        Computes the constraints of the CMP based on the current value of the
        primal parameters. This function returns a
        :py:class:`cooper.problem.CMPState` collecting the values of the (proxy
        and/or non-proxy) constraints. Note that this returned ``CMPState`` may
        have a ``loss`` attribute with value ``None`` since, by design, the loss
        is not necessarily computed when evaluating `only` the constraints.

        The signature of this "abstract" function may be changed to accommodate
        situations that require a model, (mini-batched) inputs/targets, or
        other arguments to be passed.

        Depending on the problem at hand, the computation of the constraints
        can be compartmentalized in a way that is independent of the evaluation
        of the loss. Alternatively,
        :py:meth:`~.ConstrainedMinimizationProblem.defect_fn` may be used in the
        implementation of the :py:meth:`~.ConstrainedMinimizationProblem.closure`
        function.
        """
        # Optional hook: subclasses that evaluate constraints separately from
        # the loss override this; otherwise it is deliberately unimplemented.
        raise NotImplementedError