Linear classification with rate constraints

In this example we consider a linear classification problem on synthetically generated mixture-of-Gaussians data. We constrain the model to predict at least 70% of the training points as blue.

Note that this is a non-differentiable constraint, and thus the typical Lagrangian approach is not applicable, as it requires computing the derivatives of the constraint functions \(g\) and \(h\).
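
For instance, thresholding a model's outputs to compute a prediction rate breaks the computational graph, so autograd cannot provide gradients for the constraint. A minimal illustration (the tensors below are hypothetical):

import torch

w = torch.tensor([1.0], requires_grad=True)
scores = w * torch.tensor([0.5, -1.0, 2.0])
rate = (scores >= 0).float().mean()  # fraction of points predicted positive
# The comparison is piecewise constant in w, so `rate` carries no grad_fn and
# rate.backward() raises an error instead of producing gradients.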

A commonly used approach to deal with this difficulty is to retain the Lagrangian formulation, but replace the constraints with differentiable approximations or surrogates. However, changing the constraint functions can result in an over- or under-constrained version of the problem: a surrogate that systematically over-estimates the constraint violation makes the problem stricter than intended, while one that under-estimates it makes the problem looser.

Cotter et al. [2019] propose a proxy-Lagrangian formulation, in which the non-differentiable constraints are relaxed only when necessary. In other words, the non-differentiable constraint functions are used to compute the Lagrangian value and the constraint violations (and thus the multiplier updates), while the surrogates are used to compute the primal gradients.
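
Schematically, one simultaneous update under this scheme looks as follows. This is a minimal, self-contained sketch of the idea on a toy one-dimensional problem, not Cooper's actual implementation; all names are illustrative:

import torch

# Toy problem: minimize (w - 2)^2 subject to at least half of the scores
# w * x being negative. g_true is the true, non-differentiable rate defect;
# g_hinge is its differentiable surrogate.
x = torch.tensor([-1.0, 0.5, 2.0])
w = torch.nn.Parameter(torch.tensor(0.0))
multiplier = torch.tensor(0.0)
primal_opt = torch.optim.SGD([w], lr=0.1)

for _ in range(200):
    scores = w * x
    loss = (w - 2.0) ** 2
    g_hinge = 0.5 - torch.mean(1 - torch.sigmoid(scores))  # surrogate defect
    g_true = 0.5 - torch.mean((scores < 0).float())        # true defect

    # Primal step: gradients flow through the surrogate defect only
    primal_opt.zero_grad()
    (loss + multiplier * g_hinge).backward()
    primal_opt.step()

    # Dual step: gradient ascent on the true defect, projected onto >= 0
    with torch.no_grad():
        multiplier = torch.clamp(multiplier + 0.1 * g_true, min=0.0)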

This example is based on Fig. 2 of Cotter et al. [2019]. Here we present the naive setting, where proxy constraints are used within a Lagrangian formulation, rather than a “proper” proxy-Lagrangian formulation. For details on the notion of a proxy-Lagrangian formulation, see \(\S\)4.2 in Cotter et al. [2019].

Figure: Decision boundaries for the three settings. Goal: predict at least 70.0% as blue. Panel titles: Unconstrained - Pred. Blue Prop.: 50.0%; Constrained - Pred. Blue Prop.: 66.0%; Proxy - Pred. Blue Prop.: 70.0%.
import random

import matplotlib.pyplot as plt
import numpy as np
import style_utils
import torch
from style_utils import *
from torch.nn.functional import binary_cross_entropy_with_logits as bce_loss

import cooper

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)


def generate_mog_dataset():
    """
    Generate a 2D mixture-of-Gaussians dataset with two classes.
    """

    n_per_class = 100
    dim = 2
    n_gaussians = 4
    mus = [(0, 1), (-1, 0), (0, -1), (1, 0)]
    mus = [torch.tensor(m) for m in mus]
    var = 0.05

    inputs, labels = [], []

    for i in range(n_gaussians):
        # Sample from N(mu, var * I) as mu + x @ sqrt_cov, with x ~ N(0, I)
        sqrt_cov = np.sqrt(var) * torch.eye(dim)
        mu = mus[i]
        inputs.append(mu + torch.randn(n_per_class, dim) @ sqrt_cov)

        # The first two Gaussians belong to class 1 (red), the rest to class 0 (blue)
        labels.append(torch.tensor(n_per_class * [1.0 if i < 2 else 0.0]))

    return torch.cat(inputs, dim=0), torch.cat(labels, dim=0)
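
With these settings the dataset contains 400 points, half of them in each class. For example:

inputs, labels = generate_mog_dataset()
print(inputs.shape, labels.shape)  # torch.Size([400, 2]) torch.Size([400])
print(labels.mean().item())        # 0.5: half of the points are labeled 1 (red)
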
def plot_pane(ax, inputs, x1, x2, achieved_const, title, colors):
    # Note: `red` and `blue` (fill colors) are module-level globals, set from
    # style_utils.COLOR_DICT in the plotting section below
    const_str = str(np.round(achieved_const, 0)) + "%"
    ax.scatter(*torch.transpose(inputs, 0, 1), color=colors)
    ax.plot(x1, x2, color="gray", linestyle="--")
    ax.fill_between(x1, -2, x2, color=blue, alpha=0.1)
    ax.fill_between(x1, x2, 2, color=red, alpha=0.1)

    ax.set_aspect("equal")
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.set_title(title + " - Pred. Blue Prop.: " + const_str)


class MixtureSeparation(cooper.ConstrainedMinimizationProblem):
    """
    Implements CMP for separating the MoG dataset with a linear predictor.

    Args:
        is_constrained: Flag to apply or not the constraint on the percentage of
            points predicted as belonging to the blue class.
        use_proxy: Flag to use proxy-constraints. If ``True``, we use a hinge
            relaxation. Defaults to ``False``.
        const_level: Minimum proportion of points to be predicted as belonging
            to the blue class. Ignored when ``is_constrained==False``. Defaults
            to ``0.7``.
    """

    def __init__(
        self, is_constrained: bool, use_proxy: bool = False, const_level: float = 0.7
    ):
        super().__init__()

        self.is_constrained = is_constrained
        self.const_level = const_level
        self.use_proxy = use_proxy
    def closure(self, model, inputs, targets):
        logits = model(inputs)
        loss = bce_loss(logits.flatten(), targets)

        if not self.is_constrained:
            # Unconstrained problem of separating the two classes
            state = cooper.CMPState(loss=loss)
        else:
            # Separate the classes s.t. at least const_level of the points are
            # predicted as class 0 (blue)

            # Differentiable hinge approximation of the blue prediction rate
            probs = torch.sigmoid(logits)
            hinge = torch.mean(torch.max(torch.zeros_like(probs), 1 - probs))

            # Constraint: hinge >= const_level, i.e. const_level - hinge <= 0
            hinge_defect = self.const_level - hinge

            if not self.use_proxy:
                ineq_defect = hinge_defect
                proxy_ineq_defect = None
            else:
                # The non-differentiable defect updates the Lagrange multipliers;
                # the hinge defect is only used for the primal (model) gradients

                # Proportion of points predicted as class 0 (the true rate)
                classes = logits >= 0.0
                prop_0 = torch.sum(classes == 0) / targets.numel()
                ineq_defect = self.const_level - prop_0
                proxy_ineq_defect = hinge_defect

            state = cooper.CMPState(
                loss=loss,
                ineq_defect=ineq_defect,
                proxy_ineq_defect=proxy_ineq_defect,
            )

        return state
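
To see the difference between the two defects on concrete (hypothetical) logits:

logits = torch.tensor([-1.0, 2.0, -3.0])
true_rate = torch.mean((logits < 0).float())       # 2/3 predicted blue; piecewise constant, no useful gradient
soft_rate = torch.mean(1 - torch.sigmoid(logits))  # ~0.60; differentiable everywhere
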
def train(problem_name, inputs, targets, num_iters=5000, lr=1e-2, const_level=0.7):
    """
    Train a linear model via SGD (simultaneous primal-dual updates when constrained).
    """

    is_constrained = problem_name.lower() in ["constrained", "proxy"]
    use_proxy = problem_name.lower() == "proxy"

    model = torch.nn.Linear(2, 1)
    primal_optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.7)

    cmp = MixtureSeparation(is_constrained, use_proxy, const_level)

    if is_constrained:
        formulation = cooper.LagrangianFormulation(cmp)

        dual_optimizer = cooper.optim.partial_optimizer(
            torch.optim.SGD, lr=lr, momentum=0.7
        )

        constrained_optimizer = cooper.SimultaneousConstrainedOptimizer(
            formulation=formulation,
            primal_optimizers=primal_optimizer,
            dual_optimizer=dual_optimizer,
        )
    else:
        formulation = cooper.UnconstrainedFormulation(cmp)
        dual_optimizer = None

        constrained_optimizer = cooper.UnconstrainedOptimizer(
            formulation=formulation, primal_optimizers=primal_optimizer
        )

    for _ in range(num_iters):
        constrained_optimizer.zero_grad()
        if is_constrained:
            lagrangian = formulation.compute_lagrangian(
                cmp.closure, model, inputs, targets
            )
            formulation.backward(lagrangian)
        else:
            # No Lagrangian in the unconstrained case: backprop the loss directly
            loss = cmp.closure(model, inputs, targets).loss
            loss.backward()

        constrained_optimizer.step()

    # Proportion of points predicted as class 0 (blue) after training
    logits = model(inputs)
    pred_classes = logits >= 0.0
    prop_0 = torch.sum(pred_classes == 0) / targets.numel()

    return model, 100 * prop_0.item()
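
For reference, a single setting can be trained directly. The figure above reports predicted-blue proportions of 50.0%, 66.0% and 70.0% for the unconstrained, constrained and proxy settings, respectively (exact values depend on the random seed):

inputs, labels = generate_mog_dataset()
model, blue_prop = train("proxy", inputs, labels)  # blue_prop should be close to 70.0
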
# Plot configs
titles = ["Unconstrained", "Constrained", "Proxy"]
fig, axs = plt.subplots(1, 3, figsize=(18, 6))

# Data and training configs
inputs, labels = generate_mog_dataset()
const_level = 0.7
lr = 2e-2
num_iters = 5000

for idx, name in enumerate(titles):
    model, achieved_const = train(
        name, inputs, labels, lr=lr, num_iters=num_iters, const_level=const_level
    )

    # Compute the decision boundary of the trained linear model
    weight, bias = model.weight.data.flatten().numpy(), model.bias.data.numpy()
    x1 = np.linspace(-2, 2, 100)
    x2 = (-1 / weight[1]) * (weight[0] * x1 + bias)

    # Color points according to their true label
    red, blue = style_utils.COLOR_DICT["red"], style_utils.COLOR_DICT["blue"]
    colors = [red if label == 1 else blue for label in labels.flatten()]
    plot_pane(axs[idx], inputs, x1, x2, achieved_const, name, colors)

fig.suptitle("Goal: Predict at least " + str(const_level * 100) + "% as blue")
plt.show()
