Linear classification with rate constraints

In this example we consider a linear classification problem on a synthetically generated mixture-of-Gaussians dataset. We constrain the model to predict at least 70% of the training points as blue.
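Denoting the linear model by \(f_\theta\) and recalling that blue corresponds to class 0 in the code below (a point is predicted as blue when its logit is negative), the problem can be stated as

\[ \min_{\theta} \; \frac{1}{n} \sum_{i=1}^{n} \mathrm{BCE}\big(f_\theta(x_i), y_i\big) \quad \text{s.t.} \quad \frac{1}{n} \sum_{i=1}^{n} \mathbf{1}\left[f_\theta(x_i) < 0\right] \geq 0.7. \]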

Note that this is a non-differentiable constraint, so the typical Lagrangian approach is not applicable, as it requires computing the derivatives of the constraints \(g\) and \(h\).
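Concretely, for a problem \(\min_x f(x)\) subject to \(g(x) \leq 0\) and \(h(x) = 0\), the Lagrangian approach performs gradient descent on \(x\) and gradient ascent on the multipliers over

\[ \mathcal{L}(x, \lambda, \mu) = f(x) + \lambda^\top g(x) + \mu^\top h(x), \]

so the primal update requires \(\nabla_x g\) and \(\nabla_x h\). For a rate constraint such as the one above, the indicator \(\mathbf{1}[\cdot]\) makes these gradients zero almost everywhere.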

A commonly used approach to deal with this difficulty is to retain the Lagrangian formulation, but replace the constraints with differentiable approximations or surrogates. However, changing the constraint functions can result in an over- or under-constrained version of the problem.

Cotter et al. [2019] propose a proxy-Lagrangian formulation, in which the non-differentiable constraints are relaxed only when necessary. In other words, the non-differentiable constraint functions are used to compute the Lagrangian value and the constraint violations (and thus the multiplier updates), while the surrogates are used to compute the primal gradients.
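Schematically, one primal-dual step with a proxy constraint looks as follows. This is a minimal sketch with illustrative names (loss, lmbda, surrogate_defect, true_defect, dual_lr), not Cooper's actual API:

    # Primal step: the differentiable surrogate builds the Lagrangian that
    # the model parameters differentiate through.
    lagrangian = loss + lmbda * surrogate_defect
    lagrangian.backward()
    primal_optimizer.step()

    # Dual step: the true (non-differentiable) defect drives the multiplier.
    with torch.no_grad():
        lmbda += dual_lr * true_defect  # gradient ascent on the true violation
        lmbda.clamp_(min=0.0)           # inequality multiplier stays non-negative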

This example is based on Fig. 2 of Cotter et al. [2019]. Here we present the naive setting where proxy constraints are used within a Lagrangian formulation, rather than a “proper” proxy-Lagrangian formulation. For details on the notion of a proxy-Lagrangian formulation, see \(\S\) 4.2 in [Cotter et al., 2019].

[Figure] Goal: Predict at least 70.0% as blue. Panels: Unconstrained - Pred. Blue Prop.: 50.0%; Constrained - Pred. Blue Prop.: 66.0%; Proxy - Pred. Blue Prop.: 70.0%.
import random

import matplotlib.pyplot as plt
import numpy as np
import style_utils
import torch
from torch.nn.functional import binary_cross_entropy_with_logits as bce_loss

import cooper

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

def generate_mog_dataset():
    """
    Generate a MoG dataset on 2D, with two classes.
    """

    n_per_class = 100
    dim = 2
    n_gaussians = 4
    mus = [(0, 1), (-1, 0), (0, -1), (1, 0)]
    mus = [torch.tensor(m) for m in mus]
    var = 0.05

    inputs, labels = [], []

    for cluster_id in range(n_gaussians):
        # Sample points as mu + x @ sqrt(cov), with isotropic covariance
        cov = np.sqrt(var) * torch.eye(dim)
        mu = mus[cluster_id]
        inputs.append(mu + torch.randn(n_per_class, dim) @ cov)

        # The first two Gaussians belong to class 1 (red), the rest to class 0 (blue)
        labels.append(torch.tensor(n_per_class * [1.0 if cluster_id < 2 else 0.0]))

    return torch.cat(inputs, dim=0), torch.cat(labels, dim=0)

def plot_pane(ax, inputs, x1, x2, achieved_const, title, colors):
    const_str = str(np.round(achieved_const, 0)) + "%"
    ax.scatter(*torch.transpose(inputs, 0, 1), color=colors)
    ax.plot(x1, x2, color="gray", linestyle="--")
    ax.fill_between(x1, -2, x2, color=blue, alpha=0.1)
    ax.fill_between(x1, x2, 2, color=red, alpha=0.1)

    ax.set_aspect("equal")
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.set_title(title + " - Pred. Blue Prop.: " + const_str)

class MixtureSeparation(cooper.ConstrainedMinimizationProblem):
    """
    Implements CMP for separating the MoG dataset with a linear predictor.

    Args:
        is_constrained: Whether to enforce the constraint on the proportion of
            points predicted as belonging to the blue class.
        use_proxy: Whether to use proxy constraints. If ``True``, the multiplier
            updates use the true rate constraint, while the primal updates use a
            hinge relaxation. Defaults to ``False``.
        const_level: Minimum proportion of points to be predicted as belonging
            to the blue class. Ignored when ``is_constrained==False``. Defaults
            to ``0.7``.
    """

    def __init__(
        self, is_constrained: bool, use_proxy: bool = False, const_level: float = 0.7
    ):

        super().__init__(is_constrained=is_constrained)

        self.const_level = const_level
        self.use_proxy = use_proxy

    def closure(self, model, inputs, targets):

        logits = model(inputs)
        loss = bce_loss(logits.flatten(), targets)

        if not self.is_constrained:
            # Unconstrained problem of separating two classes
            state = cooper.CMPState(loss=loss)
        else:
            # Separating classes s.t. predicting at least const_level as class 0

            # Hinge (differentiable) approximation of the rate of blue predictions
            probs = torch.sigmoid(logits)
            hinge = torch.mean(torch.max(torch.zeros_like(probs), 1 - probs))

            # Constraint: hinge >= const_level, i.e. const_level - hinge <= 0
            hinge_defect = self.const_level - hinge

            if not self.use_proxy:
                ineq_defect = hinge_defect
                proxy_ineq_defect = None
            else:
                # Use the true (non-differentiable) rate constraint to update
                # the Lagrange multipliers, and the hinge relaxation for the
                # primal gradients

                # Proportion of points predicted as class 0 is the non-proxy defect
                classes = logits >= 0.0
                prop_0 = torch.sum(classes == 0) / targets.numel()
                ineq_defect = self.const_level - prop_0
                proxy_ineq_defect = hinge_defect

            state = cooper.CMPState(
                loss=loss,
                ineq_defect=ineq_defect,
                proxy_ineq_defect=proxy_ineq_defect,
            )

        return state

def train(problem_name, inputs, targets, num_iters=5000, lr=1e-2, const_level=0.7):
    """
    Train via SGD
    """

    is_constrained = problem_name.lower() in ["constrained", "proxy"]
    use_proxy = problem_name.lower() == "proxy"

    model = torch.nn.Linear(2, 1)

    cmp = MixtureSeparation(is_constrained, use_proxy, const_level)
    formulation = cooper.LagrangianFormulation(cmp)

    primal_optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.7)

    if is_constrained:
        dual_optimizer = cooper.optim.partial_optimizer(
            torch.optim.SGD, lr=lr, momentum=0.7
        )
    else:
        dual_optimizer = None

    constrained_optimizer = cooper.ConstrainedOptimizer(
        formulation=formulation,
        primal_optimizer=primal_optimizer,
        dual_optimizer=dual_optimizer,
    )

    for _ in range(num_iters):
        constrained_optimizer.zero_grad()
        if is_constrained:
            lagrangian = formulation.composite_objective(
                cmp.closure, model, inputs, targets
            )
            formulation.custom_backward(lagrangian)
        else:
            # No Lagrangian in the unconstrained case
            loss = cmp.closure(model, inputs, targets).loss
            loss.backward()

        constrained_optimizer.step()

    # Proportion of points predicted as class 0 in the train set after training
    logits = model(inputs)
    pred_classes = logits >= 0.0
    prop_0 = torch.sum(pred_classes == 0) / targets.numel()

    return model, 100 * prop_0.item()

# Plot configs
titles = ["Unconstrained", "Constrained", "Proxy"]
fig, axs = plt.subplots(1, 3, figsize=(18, 6))

# Colors for the two classes
red, blue = style_utils.COLOR_DICT["red"], style_utils.COLOR_DICT["blue"]

# Data and training configs
inputs, labels = generate_mog_dataset()
const_level = 0.7
lr = 2e-2
num_iters = 5000

for idx, name in enumerate(titles):

    model, achieved_const = train(
        name, inputs, labels, lr=lr, num_iters=num_iters, const_level=const_level
    )

    # Compute decision boundary
    weight, bias = model.weight.data.flatten().numpy(), model.bias.data.numpy()
    x1 = np.linspace(-2, 2, 100)
    x2 = (-1 / weight[1]) * (weight[0] * x1 + bias)

    # Color points according to their true label
    colors = [red if label == 1 else blue for label in labels.flatten()]
    plot_pane(axs[idx], inputs, x1, x2, achieved_const, titles[idx], colors)

fig.suptitle("Goal: Predict at least " + str(const_level * 100) + "% as blue")
plt.show()
