Note
Click here to download the full example code
Linear classification with rate constraints
In this example we consider a linear classification problem on synthetic mixture-of-Gaussians data. We constrain the model to predict at least 70% of the training points as blue.
Note that this is a non-differentiable constraint, and thus the typical Lagrangian approach is not applicable, as it requires computing the derivatives of the constraints \(g\) and \(h\).
A commonly used approach to deal with this difficulty is to retain the Lagrangian formulation, but replace the constraints with differentiable approximations or surrogates. However, changing the constraint functions can result in an over- or under-constrained version of the problem.
Cotter et al. [2019] propose a proxy-Lagrangian formulation, in which the non-differentiable constraints are relaxed only when necessary. In other words, the non-differentiable constraint functions are used to compute the Lagrangian and the constraint violations (and thus the multiplier updates), while the surrogates are used to compute the primal gradients.
This example is based on Fig. 2 of Cotter et al. [2019]. Here we present the naive setting where proxy-constraints are used on a Lagrangian formulation, rather than a “proper” proxy-Lagrangian formulation. For details on the notion of a proxy-Lagrangian formulation, see \(\S\) 4.2 in [Cotter et al., 2019].
31 import random
32
33 import matplotlib.pyplot as plt
34 import numpy as np
35 import style_utils
36 import torch
37 from torch.nn.functional import binary_cross_entropy_with_logits as bce_loss
38
39 import cooper
40
# Seed every random number generator the script imports so that repeated runs
# start from the same state.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
44
45
def generate_mog_dataset(n_per_gaussian=100, var=0.05):
    """
    Generate a two-class mixture-of-Gaussians (MoG) dataset in 2D.

    Four isotropic Gaussians are centred at (0, 1), (-1, 0), (0, -1) and
    (1, 0). Points drawn from the first two Gaussians are labelled ``1.0`` and
    points drawn from the last two are labelled ``0.0``.

    Args:
        n_per_gaussian: Number of points sampled from each of the four
            Gaussians. Defaults to ``100``.
        var: Variance of every Gaussian along each axis. Defaults to ``0.05``.

    Returns:
        Tuple ``(inputs, labels)``. ``inputs`` has shape
        ``(4 * n_per_gaussian, 2)`` and ``labels`` has shape
        ``(4 * n_per_gaussian,)``. Samples are concatenated Gaussian by
        Gaussian, in the order of the centres listed above.
    """
    dim = 2
    centres = [(0.0, 1.0), (-1.0, 0.0), (0.0, -1.0), (1.0, 0.0)]

    # Samples are drawn as ``mu + z @ sqrt(cov)``. For an isotropic Gaussian the
    # square root of the covariance is ``sqrt(var) * I``. It is the same for
    # every component, so it is built once; the numpy scalar is converted to a
    # plain float so the product does not depend on numpy/torch dispatch.
    sqrt_cov = float(np.sqrt(var)) * torch.eye(dim)

    inputs, labels = [], []
    for gaussian_idx, centre in enumerate(centres):
        mu = torch.tensor(centre)
        inputs.append(mu + torch.randn(n_per_gaussian, dim) @ sqrt_cov)

        # The first two Gaussians form class 1, the remaining two class 0.
        label = 1.0 if gaussian_idx < 2 else 0.0
        labels.append(torch.full((n_per_gaussian,), label))

    return torch.cat(inputs, dim=0), torch.cat(labels, dim=0)
70
71
def plot_pane(ax, inputs, x1, x2, achieved_const, titles, colors, pane_idx=None):
    """
    Draw one panel: the data points, the decision boundary and the shaded
    regions on either side of it.

    Args:
        ax: Matplotlib axes to draw on.
        inputs: Tensor of 2D points. It is transposed and unpacked into the
            positional arguments of ``ax.scatter``.
        x1: Horizontal coordinates of the decision boundary.
        x2: Vertical coordinates of the decision boundary at ``x1``.
        achieved_const: Value shown in the title as the predicted blue
            proportion, rounded to zero decimals and followed by ``%``.
        titles: Sequence of panel titles.
        colors: Per-point colors forwarded to ``ax.scatter``.
        pane_idx: Position of this panel's title in ``titles``. When ``None``
            (the default) the module-level variable ``idx`` is used, which is
            what callers that predate this parameter rely on.
    """
    # Resolve the shading colors here instead of depending on ``red``/``blue``
    # having been defined as module globals before the call.
    red = style_utils.COLOR_DICT["red"]
    blue = style_utils.COLOR_DICT["blue"]

    if pane_idx is None:
        # Backward-compatible fallback: use the caller's global loop index.
        pane_idx = idx

    const_str = str(np.round(achieved_const, 0)) + "%"

    ax.scatter(*torch.transpose(inputs, 0, 1), color=colors)
    ax.plot(x1, x2, color="gray", linestyle="--")
    # Shade from the bottom of the plot up to the boundary in blue, and from
    # the boundary up to the top of the plot in red.
    ax.fill_between(x1, -2, x2, color=blue, alpha=0.1)
    ax.fill_between(x1, x2, 2, color=red, alpha=0.1)

    ax.set_aspect("equal")
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.set_title(titles[pane_idx] + " - Pred. Blue Prop.: " + const_str)
83
84
class MixtureSeparation(cooper.ConstrainedMinimizationProblem):
    """
    Constrained minimization problem (CMP) for separating the MoG dataset with
    a linear predictor.

    The objective is the binary cross-entropy between the model logits and the
    targets. When constrained, a single inequality constraint asks that at
    least ``const_level`` of the points be predicted as class 0 (blue).

    Args:
        is_constrained: Flag to apply or not the constraint on the percentage
            of points predicted as belonging to the blue class.
        use_proxy: Flag to use proxy-constraints. If ``True``, the defect is
            the exact predicted proportion and the differentiable relaxation
            is passed as the proxy defect. If ``False``, the relaxation is
            used as the defect itself. Defaults to ``False``.
        const_level: Minimum proportion of points to be predicted as belonging
            to the blue class. Ignored when ``is_constrained==False``.
            Defaults to ``0.7``.
    """

    def __init__(
        self, is_constrained: bool, use_proxy: bool = False, const_level: float = 0.7
    ):
        super().__init__(is_constrained=is_constrained)

        self.use_proxy = use_proxy
        self.const_level = const_level

    def closure(self, model, inputs, targets):
        """
        Evaluate the loss and, if constrained, the constraint defects.

        Args:
            model: Callable mapping ``inputs`` to logits.
            inputs: Batch passed to ``model``.
            targets: Labels compared against the flattened logits.

        Returns:
            A ``cooper.CMPState`` holding the loss and, in the constrained
            case, ``ineq_defect`` and ``proxy_ineq_defect``.
        """
        logits = model(inputs)
        loss = bce_loss(logits.flatten(), targets)

        # Unconstrained problem of separating two classes: only the loss.
        if not self.is_constrained:
            return cooper.CMPState(loss=loss)

        # Differentiable relaxation of the proportion predicted as class 0:
        # the mean of max(0, 1 - sigmoid(logits)).
        probs = torch.sigmoid(logits)
        relaxed_rate = torch.mean(torch.max(torch.zeros_like(probs), 1 - probs))

        # The constraint reads: const_level - rate <= 0
        relaxed_defect = self.const_level - relaxed_rate

        if self.use_proxy:
            # Exact (non-differentiable) proportion of points whose logit is
            # not >= 0, i.e. points predicted as class 0. It is the true
            # defect; the relaxation is handed over as the proxy.
            predicted_class_0 = ~(logits >= 0.0)
            prop_0 = predicted_class_0.sum() / targets.numel()
            ineq_defect = self.const_level - prop_0
            proxy_ineq_defect = relaxed_defect
        else:
            ineq_defect = relaxed_defect
            proxy_ineq_defect = None

        return cooper.CMPState(
            loss=loss,
            ineq_defect=ineq_defect,
            proxy_ineq_defect=proxy_ineq_defect,
        )
147
148
def train(problem_name, inputs, targets, num_iters=5000, lr=1e-2, const_level=0.7):
    """
    Train a linear model on ``inputs``/``targets`` with SGD.

    Args:
        problem_name: Case-insensitive name of the setting. ``"constrained"``
            and ``"proxy"`` enable the constraint; ``"proxy"`` additionally
            enables proxy-constraints. Any other value trains without the
            constraint.
        inputs: Training points, fed to a ``torch.nn.Linear(2, 1)`` model.
        targets: Training labels.
        num_iters: Number of full-batch optimization steps.
        lr: Learning rate shared by the primal and dual SGD optimizers.
        const_level: Constraint level forwarded to ``MixtureSeparation``.

    Returns:
        Tuple ``(model, percentage)`` where ``percentage`` is 100 times the
        fraction of training points predicted as class 0 after training.
    """
    setting = problem_name.lower()
    use_proxy = setting == "proxy"
    is_constrained = use_proxy or setting == "constrained"

    model = torch.nn.Linear(2, 1)

    cmp = MixtureSeparation(is_constrained, use_proxy, const_level)
    formulation = cooper.LagrangianFormulation(cmp)

    primal_optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.7)
    # A dual optimizer is only created when there is a constraint.
    dual_optimizer = (
        cooper.optim.partial_optimizer(torch.optim.SGD, lr=lr, momentum=0.7)
        if is_constrained
        else None
    )

    constrained_optimizer = cooper.ConstrainedOptimizer(
        formulation=formulation,
        primal_optimizer=primal_optimizer,
        dual_optimizer=dual_optimizer,
    )

    for _ in range(num_iters):
        constrained_optimizer.zero_grad()
        if not is_constrained:
            # No Lagrangian in the unconstrained case
            cmp.closure(model, inputs, targets).loss.backward()
        else:
            lagrangian = formulation.composite_objective(
                cmp.closure, model, inputs, targets
            )
            formulation.custom_backward(lagrangian)

        constrained_optimizer.step()

    # Fraction of the train set predicted as class 0 (logit not >= 0) after
    # training.
    predicted_class_0 = ~(model(inputs) >= 0.0)
    prop_0 = predicted_class_0.sum() / targets.numel()

    return model, 100 * prop_0.item()
195
196
# Plot configs
titles = ["Unconstrained", "Constrained", "Proxy"]
fig, axs = plt.subplots(1, 3, figsize=(18, 6))

# Data and training configs
inputs, labels = generate_mog_dataset()
const_level = 0.7
lr = 2e-2
num_iters = 5000

# The point colors and the boundary abscissas do not depend on the trained
# model, so they are computed once. Points are colored by their true label.
red, blue = style_utils.COLOR_DICT["red"], style_utils.COLOR_DICT["blue"]
colors = [red if label == 1 else blue for label in labels.flatten().tolist()]
x1 = np.linspace(-2, 2, 100)

# NOTE: ``idx``, ``red`` and ``blue`` are deliberately module-level names.
for idx, (name, ax) in enumerate(zip(titles, axs)):

    model, achieved_const = train(
        name, inputs, labels, lr=lr, num_iters=num_iters, const_level=const_level
    )

    # Decision boundary: the line where the logit w0 * x1 + w1 * x2 + b is zero
    weight = model.weight.data.flatten().numpy()
    bias = model.bias.data.numpy()
    x2 = (-1 / weight[1]) * (weight[0] * x1 + bias)

    plot_pane(ax, inputs, x1, x2, achieved_const, titles, colors)

fig.suptitle(f"Goal: Predict at least {const_level * 100}% as blue")
plt.show()
Total running time of the script: ( 0 minutes 11.069 seconds)