Training a logistic regression classifier on MNIST under a norm constraint

Here we consider a simple convex constrained optimization problem: training a logistic regression classifier on the MNIST dataset, subject to the constraint that the squared L2 norm of the model parameters be at most 1.
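
Concretely, writing \theta for the (flattened) parameters of the model and f(\theta) for the cross-entropy training loss, the problem reads

  \min_{\theta} \; f(\theta) \quad \text{s.t.} \quad g(\theta) = \|\theta\|_2^2 - 1 \le 0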

This example illustrates:

  • constructing a cooper.LagrangianFormulation and a cooper.SimultaneousConstrainedOptimizer (the updates they perform are sketched after this list),

  • using models defined as a torch.nn.Module,

  • enabling CUDA acceleration,

  • integrating Cooper into a typical machine learning training loop,

  • extracting the value of the Lagrange multipliers from a cooper.LagrangianFormulation.
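
In broad strokes, the cooper.LagrangianFormulation turns the constrained problem into a min-max problem over the Lagrangian, and the cooper.SimultaneousConstrainedOptimizer updates the model parameters and the Lagrange multiplier at the same time: gradient descent on \theta alongside projected gradient ascent on \lambda. Schematically (with \eta_p and \eta_d standing for the primal and dual step sizes):

  L(\theta, \lambda) = f(\theta) + \lambda \, g(\theta), \qquad \lambda \ge 0
  \theta_{t+1} = \theta_t - \eta_p \, \nabla_\theta L(\theta_t, \lambda_t)
  \lambda_{t+1} = \max\!\left(0, \; \lambda_t + \eta_d \, g(\theta_t)\right)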

[Figure with four panels: Training Loss, Training Acc, Inequality Multiplier, Inequality Defect]
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

/home/docs/checkouts/readthedocs.org/user_builds/cooper/envs/dev/lib/python3.8/site-packages/torch/utils/data/dataloader.py:554: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(

import matplotlib.pyplot as plt
import numpy as np
import torch
from style_utils import *
from torchvision import datasets, transforms

import cooper

np.random.seed(0)
torch.manual_seed(0)

data_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST("data", train=True, download=True, transform=data_transforms),
    batch_size=256,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)
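# Note: num_workers=4 triggers the DataLoader warning shown above on machines
# with few CPU cores; lower it if data loading becomes slow or hangs.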

loss_fn = torch.nn.CrossEntropyLoss()

# Create a logistic regression model
model = torch.nn.Linear(in_features=28 * 28, out_features=10, bias=True)
if torch.cuda.is_available():
    model = model.cuda()
primal_optimizer = torch.optim.Adagrad(model.parameters(), lr=5e-3)

# Create a Cooper formulation, and pick a PyTorch optimizer class for the dual variables
formulation = cooper.LagrangianFormulation()
dual_optimizer = cooper.optim.partial_optimizer(torch.optim.SGD, lr=1e-3)
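# Note: partial_optimizer holds the optimizer class and its hyperparameters
# without instantiating it, since the dual variables it will update are only
# created by Cooper once the constraints have been observed.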

# Create a ConstrainedOptimizer for performing simultaneous updates based on the
# formulation and the selected primal and dual optimizers.
cooper_optimizer = cooper.SimultaneousConstrainedOptimizer(
    formulation, primal_optimizer, dual_optimizer
)

all_metrics = {
    "batch_ix": [],
    "train_loss": [],
    "train_acc": [],
    "ineq_multiplier": [],
    "ineq_defect": [],
}

batch_ix = 0

for epoch_num in range(7):

    for inputs, targets in train_loader:
        batch_ix += 1

        if torch.cuda.is_available():
            inputs, targets = inputs.cuda(), targets.cuda()

        logits = model(inputs.view(inputs.shape[0], -1))
        loss = loss_fn(logits, targets)
        accuracy = (logits.argmax(dim=1) == targets).float().mean()

        sq_l2_norm = model.weight.pow(2).sum() + model.bias.pow(2).sum()
        # Constraint defects use the convention "g(x) - epsilon <= 0"
        constraint_defect = sq_l2_norm - 1.0

        # Create a CMPState object, which contains the loss and constraint defect
        cmp_state = cooper.CMPState(loss=loss, ineq_defect=constraint_defect)

        cooper_optimizer.zero_grad()
        lagrangian = formulation.compute_lagrangian(pre_computed_state=cmp_state)
        formulation.backward(lagrangian)
        cooper_optimizer.step()

        # Extract the value of the Lagrange multiplier associated with the constraint.
        # The dual variables are stored and updated internally by Cooper.
        lag_multiplier, _ = formulation.state()

        if batch_ix % 3 == 0:
            all_metrics["batch_ix"].append(batch_ix)
            all_metrics["train_loss"].append(loss.item())
            all_metrics["train_acc"].append(accuracy.item())
            all_metrics["ineq_multiplier"].append(lag_multiplier.item())
            all_metrics["ineq_defect"].append(constraint_defect.item())

fig, (ax0, ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=4, sharex=True, figsize=(18, 4))

ax0.plot(all_metrics["batch_ix"], all_metrics["train_loss"])
ax0.set_xlabel("Batch")
ax0.set_title("Training Loss")

ax1.plot(all_metrics["batch_ix"], all_metrics["train_acc"])
ax1.set_xlabel("Batch")
ax1.set_title("Training Acc")

ax2.plot(all_metrics["batch_ix"], np.stack(all_metrics["ineq_multiplier"]))
ax2.set_xlabel("Batch")
ax2.set_title("Inequality Multiplier")

ax3.plot(all_metrics["batch_ix"], np.stack(all_metrics["ineq_defect"]))
# Show that the defect converges close to zero
ax3.axhline(0.0, c="gray", alpha=0.5)
ax3.set_xlabel("Batch")
ax3.set_title("Inequality Defect")

plt.show()
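
As a sanity check, one can verify that the trained parameters end up (close to) feasible. A minimal sketch, reusing the model trained above:

with torch.no_grad():
    # The constraint asked for a squared L2 norm of at most 1; after training,
    # the norm should hover near the boundary of the feasible set.
    final_sq_norm = model.weight.pow(2).sum() + model.bias.pow(2).sum()
print(f"Squared L2 norm after training: {final_sq_norm.item():.4f}")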

Total running time of the script: ( 1 minutes 13.122 seconds)

Gallery generated by Sphinx-Gallery