Note
Click here to download the full example code
Linear classification with rate constraints
In this example we consider a linear classification problem on synthetic mixture-of-Gaussians data. We constrain the model to predict at least 70% of the training points as blue.
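Note that this is a non-differentiable constraint: the proportion of points predicted as blue is piecewise constant in the model parameters, so its gradient is zero almost everywhere. The typical Lagrangian approach is therefore not applicable, as it requires computing the derivatives of the constraint functions \(g\) (inequality constraints) and \(h\) (equality constraints).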
A commonly used approach to deal with this difficulty is to retain the Lagrangian formulation, but replace the constraints with differentiable approximations or surrogates. However, changing the constraint functions can result in an over- or under-constrained version of the problem.
Cotter et al. [2019] propose a proxy-Lagrangian formulation, in which the non-differentiable constraints are relaxed only when necessary. In other words, the non-differentiable constraint functions are used to compute the Lagrangian and constraint violations (and thus the multiplier updates), while the surrogates are used to compute the primal gradients.
This example is based on Fig. 2 of Cotter et al. [2019]. Here we present the naive setting where proxy-constraints are used on a Lagrangian formulation, rather than a “proper” proxy-Lagrangian formulation. For details on the notion of a proxy-Lagrangian formulation, see \(\S\)4.2 of Cotter et al. [2019].
31 import random
32
33 import matplotlib.pyplot as plt
34 import numpy as np
35 import style_utils
36 import torch
37 from style_utils import *
38 from torch.nn.functional import binary_cross_entropy_with_logits as bce_loss
39
40 import cooper
41
# Seed every random number generator used by the example (torch, NumPy and
# Python's ``random``) with 0 so that the generated data and training runs are
# reproducible.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(0)
del _seed_fn
45
46
def generate_mog_dataset(n_per_class=100, var=0.05):
    """
    Generate a 2D mixture-of-Gaussians dataset with two classes.

    Four isotropic Gaussians are centred at (0, 1), (-1, 0), (0, -1) and
    (1, 0). Samples drawn from the first two centres get label ``1.0``; samples
    drawn from the last two get label ``0.0``.

    Args:
        n_per_class: Number of samples drawn from *each* Gaussian, so every
            class receives ``2 * n_per_class`` points. Defaults to ``100``.
        var: Per-coordinate variance of every Gaussian. Defaults to ``0.05``.

    Returns:
        A tuple ``(inputs, labels)``: ``inputs`` has shape
        ``(4 * n_per_class, 2)`` and ``labels`` has shape
        ``(4 * n_per_class,)``. Samples are grouped by Gaussian, in the order
        the centres are listed above.
    """
    dim = 2
    mus = [torch.tensor(m) for m in [(0, 1), (-1, 0), (0, -1), (1, 0)]]

    # Identical for every component, so build it once. Right-multiplying
    # standard-normal samples by sqrt(var) * I gives them variance ``var``.
    std = np.sqrt(var) * torch.eye(dim)

    inputs, labels = [], []
    for gaussian_idx, mu in enumerate(mus):
        # One draw per component, in order, so seeded runs are reproducible.
        inputs.append(mu + torch.randn(n_per_class, dim) @ std)

        # The first two Gaussians form class 1, the remaining two class 0.
        label = 1.0 if gaussian_idx < 2 else 0.0
        labels.append(torch.tensor(n_per_class * [label]))

    return torch.cat(inputs, dim=0), torch.cat(labels, dim=0)
71
72
def plot_pane(
    ax, inputs, x1, x2, achieved_const, titles, colors, idx=None, fill_colors=None
):
    """
    Draw one pane: the data points, a linear decision boundary and the two
    shaded half-planes.

    Args:
        ax: Matplotlib axes to draw on.
        inputs: Points to scatter. The tensor is transposed and unpacked into
            ``ax.scatter``, so each column is passed as one coordinate array
            (presumably shape ``(N, 2)``).
        x1, x2: Coordinates of the decision-boundary line.
        achieved_const: Percentage of points predicted as blue. It is shown
            rounded to an integer value in the title.
        titles: Either the pane title itself (a ``str``) or a sequence of
            titles indexed by ``idx``.
        colors: Per-point colors for the scatter plot.
        idx: Index into ``titles``. When omitted and ``titles`` is not a
            string, falls back to the module-level ``idx`` bound by the
            plotting loop (the historical behavior).
        fill_colors: ``(below, above)`` colors for the regions below and above
            the boundary. When omitted, falls back to the module-level
            ``blue`` and ``red`` (the historical behavior).
    """
    if isinstance(titles, str):
        title = titles
    else:
        if idx is None:
            # Backward compatibility: this used to read the plotting loop's
            # ``idx`` straight from module scope.
            idx = globals()["idx"]
        title = titles[idx]

    if fill_colors is None:
        fill_colors = (blue, red)
    below_color, above_color = fill_colors

    const_str = str(np.round(achieved_const, 0)) + "%"
    ax.scatter(*torch.transpose(inputs, 0, 1), color=colors)
    ax.plot(x1, x2, color="gray", linestyle="--")
    ax.fill_between(x1, -2, x2, color=below_color, alpha=0.1)
    ax.fill_between(x1, x2, 2, color=above_color, alpha=0.1)

    ax.set_aspect("equal")
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.set_title(title + " - Pred. Blue Prop.: " + const_str)
84
85
class MixtureSeparation(cooper.ConstrainedMinimizationProblem):
    """
    Implements CMP for separating the MoG dataset with a linear predictor.

    The objective is the binary cross-entropy between the model's logits and
    the targets. When constrained, we additionally require that at least
    ``const_level`` of the points be predicted as class 0 (plotted in blue).
    In cooper's ``defect <= 0`` form this reads ``const_level - rate_0 <= 0``.

    The true rate ``rate_0`` is the fraction of points whose logit is negative.
    It is piecewise constant in the model parameters, so it provides no useful
    gradient. As a differentiable surrogate we use the mean predicted
    probability of class 0, ``mean(1 - sigmoid(logits))``.

    Args:
        is_constrained: Flag to apply or not the constraint on the percentage of
            points predicted as belonging to the blue class.
        use_proxy: Flag to use proxy-constraints.
            If ``False``, the surrogate defect is reported as ``ineq_defect``
            and used everywhere.
            If ``True``, the true (non-differentiable) rate defect is reported
            as ``ineq_defect``, which is used to update the Lagrange
            multipliers. The surrogate is reported as ``proxy_ineq_defect``,
            which is used for the primal gradients.
            Ignored when ``is_constrained==False``. Defaults to ``False``.
        const_level: Minimum proportion of points to be predicted as belonging
            to the blue class. Ignored when ``is_constrained==False``. Defaults
            to ``0.7``.
    """

    def __init__(
        self, is_constrained: bool, use_proxy: bool = False, const_level: float = 0.7
    ):

        super().__init__()

        self.is_constrained = is_constrained
        self.const_level = const_level
        self.use_proxy = use_proxy

    def closure(self, model, inputs, targets):
        """
        Evaluate the loss and, if constrained, the rate-constraint defects.

        Args:
            model: Module mapping ``inputs`` to one logit per sample.
            inputs: Batch of input points.
            targets: Binary labels (``1.0`` / ``0.0``), one per sample.

        Returns:
            A ``cooper.CMPState`` holding the loss and, when constrained, the
            inequality defect and (possibly ``None``) proxy defect.
        """
        logits = model(inputs)
        loss = bce_loss(logits.flatten(), targets)

        if not self.is_constrained:
            # Unconstrained problem of separating two classes
            return cooper.CMPState(loss=loss)

        # Separating classes s.t. predicting at least const_level as class 0.

        # Differentiable surrogate of the class-0 rate: the mean predicted
        # probability of class 0. ``1 - sigmoid`` is already non-negative, so
        # no clamping at zero is needed.
        soft_prop_0 = torch.mean(1 - torch.sigmoid(logits))
        surrogate_defect = self.const_level - soft_prop_0

        if self.use_proxy:
            # True (non-differentiable) class-0 rate, used as the non-proxy
            # defect for the Lagrange multiplier updates. Points with a logit
            # that is not >= 0 are predicted as class 0.
            is_class_0 = torch.logical_not(logits >= 0.0)
            prop_0 = torch.sum(is_class_0) / targets.numel()
            ineq_defect = self.const_level - prop_0
            proxy_ineq_defect = surrogate_defect
        else:
            ineq_defect = surrogate_defect
            proxy_ineq_defect = None

        return cooper.CMPState(
            loss=loss,
            ineq_defect=ineq_defect,
            proxy_ineq_defect=proxy_ineq_defect,
        )
149
150
def train(problem_name, inputs, targets, num_iters=5000, lr=1e-2, const_level=0.7):
    """
    Train a linear classifier with full-batch SGD (momentum 0.7).

    Args:
        problem_name: Case-insensitive problem variant.
            ``"unconstrained"``: plain loss minimization.
            ``"constrained"``: Lagrangian with the surrogate rate constraint.
            ``"proxy"``: proxy-constraints (true rate for the multipliers,
            surrogate for the primal gradients).
        inputs: Training points, one 2D point per row (the model is
            ``torch.nn.Linear(2, 1)``).
        targets: Binary labels (``1.0`` / ``0.0``), one per point.
        num_iters: Number of optimization steps.
        lr: Learning rate shared by the primal and (if any) dual optimizer.
        const_level: Minimum proportion of points to be predicted as class 0.

    Returns:
        The trained ``torch.nn.Linear`` model and the percentage (0-100) of
        training points it predicts as class 0.

    Raises:
        ValueError: If ``problem_name`` is not one of the variants above.
    """
    name = problem_name.lower()
    if name not in ("unconstrained", "constrained", "proxy"):
        # Previously any unknown name silently trained the unconstrained
        # problem, hiding typos such as "proxi".
        raise ValueError(
            f"Unknown problem_name {problem_name!r}; expected 'unconstrained', "
            "'constrained' or 'proxy'."
        )

    is_constrained = name in ("constrained", "proxy")
    use_proxy = name == "proxy"

    model = torch.nn.Linear(2, 1)
    primal_optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.7)

    cmp = MixtureSeparation(is_constrained, use_proxy, const_level)

    if is_constrained:
        formulation = cooper.LagrangianFormulation(cmp)

        # The dual optimizer is given as a partial (not yet instantiated)
        # optimizer. Presumably cooper instantiates it over the Lagrange
        # multipliers once they exist (confirm against the cooper docs).
        dual_optimizer = cooper.optim.partial_optimizer(
            torch.optim.SGD, lr=lr, momentum=0.7
        )

        constrained_optimizer = cooper.SimultaneousConstrainedOptimizer(
            formulation=formulation,
            primal_optimizers=primal_optimizer,
            dual_optimizer=dual_optimizer,
        )
    else:
        formulation = cooper.UnconstrainedFormulation(cmp)
        constrained_optimizer = cooper.UnconstrainedOptimizer(
            formulation=formulation, primal_optimizers=primal_optimizer
        )

    for _ in range(num_iters):
        constrained_optimizer.zero_grad()
        if is_constrained:
            lagrangian = formulation.compute_lagrangian(
                cmp.closure, model, inputs, targets
            )
            formulation.backward(lagrangian)
        else:
            # No Lagrangian in the unconstrained case
            loss = cmp.closure(model, inputs, targets).loss
            loss.backward()

        constrained_optimizer.step()

    # Proportion of training points predicted as class 0 after training. This
    # is pure evaluation, so no autograd graph is needed.
    with torch.no_grad():
        logits = model(inputs)
    is_class_0 = torch.logical_not(logits >= 0.0)
    prop_0 = torch.sum(is_class_0) / targets.numel()

    return model, 100 * prop_0.item()
206
207
# Plot configs
titles = ["Unconstrained", "Constrained", "Proxy"]
fig, axs = plt.subplots(1, 3, figsize=(18, 6))

# Data and training configs
inputs, labels = generate_mog_dataset()
const_level = 0.7
lr = 2e-2
num_iters = 5000

# Color points according to their true label (1 -> red, 0 -> blue). This does
# not depend on the trained model, so it is computed once. ``red`` and ``blue``
# stay module-level names because ``plot_pane`` uses them to shade the
# decision regions. It also reads the loop variable ``idx`` below.
red, blue = style_utils.COLOR_DICT["red"], style_utils.COLOR_DICT["blue"]
colors = [red if label == 1 else blue for label in labels.flatten()]

for idx, name in enumerate(titles):

    model, achieved_const = train(
        name, inputs, labels, lr=lr, num_iters=num_iters, const_level=const_level
    )

    # Decision boundary w0 * x1 + w1 * x2 + b = 0, solved for x2.
    weight = model.weight.detach().flatten().numpy()
    bias = model.bias.detach().numpy()
    x1 = np.linspace(-2, 2, 100)
    x2 = (-1 / weight[1]) * (weight[0] * x1 + bias)

    plot_pane(axs[idx], inputs, x1, x2, achieved_const, titles, colors)

fig.suptitle(f"Goal: Predict at least {const_level * 100}% as blue")
plt.show()
Total running time of the script: (0 minutes 11.023 seconds)