Training a logistic regression classifier on MNIST under a norm constraint.
Note
This example illustrates how to use Cooper on a simple machine learning problem that involves using mini-batches of data.
In this example, we consider a simple convex constrained optimization problem: training a Logistic Regression classifier on the MNIST dataset. The model is constrained so that the squared L2 norm of its parameters is at most 1.
Although we employ a simple Logistic Regression model, the same principles can be applied to more complex models, such as deep neural networks.
This example illustrates how Cooper integrates with a typical PyTorch training pipeline, where:
- models are defined using a torch.nn.Module,
- steps loop over mini-batches of data, and
- CUDA acceleration is used when available.
%%capture
%pip install cooper-optim
import os
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import datasets, transforms
import cooper
# Seed both the PyTorch and the NumPy global RNGs so repeated runs are reproducible.
torch.manual_seed(0)
np.random.seed(0)

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class NormConstrainedLogisticRegression(cooper.ConstrainedMinimizationProblem):
    """Classification CMP with an inequality constraint on the squared parameter norm.

    The problem solved is

        min_w  cross_entropy(model(x), y)   s.t.   ||w||_2^2 - constraint_level <= 0,

    where ``w`` gathers every parameter of the model passed to
    ``compute_cmp_state``. A single dense Lagrange multiplier, created on
    ``DEVICE``, is attached to the constraint.
    """

    def __init__(self, constraint_level: float = 1.0):
        super().__init__()
        self.constraint_level = constraint_level

        # The multiplier must be on the same device as the model
        multiplier = cooper.multipliers.DenseMultiplier(num_constraints=1, device=DEVICE)
        self.norm_constraint = cooper.Constraint(
            constraint_type=cooper.ConstraintType.INEQUALITY,
            formulation_type=cooper.formulations.Lagrangian,
            multiplier=multiplier,
        )

    def compute_cmp_state(
        self, model: torch.nn.Module, inputs: torch.Tensor, targets: torch.Tensor
    ) -> cooper.CMPState:
        """Evaluate loss, accuracy and norm-constraint violation on one mini-batch.

        Args:
            model: Classifier mapping flattened inputs to class logits.
            inputs: Mini-batch of inputs; every dimension after the first
                (batch) dimension is flattened before the forward pass.
            targets: Integer class labels matching ``inputs``.

        Returns:
            A ``cooper.CMPState`` holding the cross-entropy ``loss``, the observed
            norm constraint, and the batch accuracy under ``misc["accuracy"]``.
        """
        # flatten(start_dim=1) keeps the batch dimension and, unlike view(),
        # also works for non-contiguous or empty batches.
        logits = model(inputs.flatten(start_dim=1))
        loss = torch.nn.functional.cross_entropy(logits, targets)
        accuracy = (logits.argmax(dim=1) == targets).float().mean()

        # Squared L2 norm over *all* parameters. For nn.Linear this equals
        # ||weight||^2 + ||bias||^2, but it also works when bias=False (where
        # model.bias is None) or for any other nn.Module.
        sq_l2_norm = sum(param.pow(2).sum() for param in model.parameters())

        # Constraint violations use convention g <= 0
        constraint_state = cooper.ConstraintState(violation=sq_l2_norm - self.constraint_level)

        # Create a CMPState object, which contains the loss and observed constraints
        observed_constraints = {self.norm_constraint: constraint_state}
        return cooper.CMPState(loss=loss, observed_constraints=observed_constraints, misc={"accuracy": accuracy})
# Load the MNIST training split, downloading it into ./data the first time.
data_path = "./data"

# Images -> tensors, then standardize with a fixed mean (0.1307) and std (0.3081);
# presumably the usual MNIST training-set statistics.
data_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
dataset = datasets.MNIST(
    data_path,
    train=True,
    download=True,
    transform=data_transforms,
)

# Mini-batches of 128 samples; page-locked host memory is requested only when
# a CUDA device is present, since it is meant to speed up host-to-GPU copies.
train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=128,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)
# Logistic regression is a single affine layer: 28 * 28 flattened pixels in,
# one logit per class (10) out. nn.Module.to returns the module itself.
model = torch.nn.Linear(28 * 28, 10, bias=True).to(DEVICE)
# Constrained problem: the squared parameter norm may not exceed 1.0.
cmp = NormConstrainedLogisticRegression(constraint_level=1.0)

# Primal player: Adam (AMSGrad variant) over the model parameters.
primal_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, amsgrad=True)

# Dual player: SGD with maximize=True, i.e. gradient *ascent* on the multipliers.
dual_optimizer = torch.optim.SGD(cmp.dual_parameters(), lr=1e-3, maximize=True)

# Cooper optimizer coordinating both players for `cmp`.
cooper_optimizer = cooper.optim.SimultaneousOptimizer(
    cmp=cmp,
    primal_optimizers=primal_optimizer,
    dual_optimizers=dual_optimizer,
)
# Create a directory to save checkpoints. exist_ok=True avoids the
# check-then-create race of os.path.exists followed by os.makedirs.
checkpoint_path = "./checkpoint"
os.makedirs(checkpoint_path, exist_ok=True)
checkpoint_file = os.path.join(checkpoint_path, "checkpoint.pth")

# Resume from the checkpoint if one exists; otherwise start from scratch.
if os.path.isfile(checkpoint_file):
    # A checkpoint whose metrics were stored as a defaultdict(list) is rejected by
    # the weights_only unpickler ("Unsupported global: GLOBAL
    # collections.defaultdict"), so explicitly allow-list the globals needed to
    # rebuild it. Checkpoints with plain-dict metrics load either way.
    # map_location lets a checkpoint written on a GPU be resumed on this DEVICE.
    with torch.serialization.safe_globals([defaultdict, list]):
        checkpoint = torch.load(checkpoint_file, map_location=DEVICE, weights_only=True)
    batch_ix = checkpoint["batch_ix"]
    start_epoch = checkpoint["epoch"] + 1
    # Re-wrap so the training loop can keep appending to (possibly new) keys.
    all_metrics = defaultdict(list, checkpoint["all_metrics"])
    model.load_state_dict(checkpoint["model_state_dict"])
    cmp.load_state_dict(checkpoint["cmp_state_dict"])
    cooper_optimizer.load_state_dict(checkpoint["cooper_optimizer_state_dict"])
else:
    batch_ix = 0
    start_epoch = 0
    all_metrics = defaultdict(list)
# Ensure metrics can be appended to even if they were restored as a plain dict.
all_metrics = defaultdict(list, all_metrics)

for epoch_num in range(start_epoch, 7):
    for inputs, targets in train_loader:
        batch_ix += 1

        # Move the batch to DEVICE (a no-op on CPU); non_blocking pairs with the
        # loader's pin_memory to allow asynchronous host-to-GPU copies.
        inputs = inputs.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        # One Cooper optimization step. `roll` receives the keyword arguments for
        # cmp.compute_cmp_state and returns the resulting CMP state and the
        # primal Lagrangian store (among other values ignored here).
        _, cmp_state, primal_lagrangian_store, _ = cooper_optimizer.roll(
            compute_cmp_state_kwargs={"model": model, "inputs": inputs, "targets": targets}
        )

        # Record metrics every 3 batches to keep the history compact.
        if batch_ix % 3 == 0:
            all_metrics["batch_ix"].append(batch_ix)
            all_metrics["train_loss"].append(cmp_state.loss.item())
            all_metrics["train_acc"].append(cmp_state.misc["accuracy"].item())

            multiplier_value = primal_lagrangian_store.multiplier_values[cmp.norm_constraint].item()
            all_metrics["multiplier_value"].append(multiplier_value)

            constraint_violation = cmp_state.observed_constraints[cmp.norm_constraint].violation
            all_metrics["constraint_violation"].append(constraint_violation.item())

    # Save checkpoint at the end of each epoch. Metrics are stored as a plain dict
    # of lists of Python numbers: a defaultdict is not an allowed global for
    # torch.load(..., weights_only=True) and would make the file fail to load.
    torch.save(
        {
            "batch_ix": batch_ix,
            "epoch": epoch_num,
            "all_metrics": dict(all_metrics),
            "model_state_dict": model.state_dict(),
            "cmp_state_dict": cmp.state_dict(),
            "cooper_optimizer_state_dict": cooper_optimizer.state_dict(),
        },
        checkpoint_path + "/checkpoint.pth",
    )
# Drop the training-time objects so the analysis below relies only on what
# was written to the checkpoint.
del batch_ix, all_metrics, model, cmp, cooper_optimizer

# Post-training analysis and plotting.
# Metrics saved as a defaultdict(list) fail to load with weights_only=True
# ("Unsupported global: GLOBAL collections.defaultdict"), so allow-list the
# globals needed to rebuild them; plain-dict metrics load either way.
# map_location="cpu" because only the metrics are needed here.
with torch.serialization.safe_globals([defaultdict, list]):
    checkpoint = torch.load(checkpoint_path + "/checkpoint.pth", map_location="cpu", weights_only=True)
all_metrics = checkpoint["all_metrics"]

fig, axes = plt.subplots(nrows=1, ncols=4, sharex=True, figsize=(18, 4))

# (metric key, panel title) for each subplot, left to right.
panels = [
    ("train_loss", "Training Loss"),
    ("train_acc", "Training Acc"),
    ("multiplier_value", "Inequality Multiplier"),
    ("constraint_violation", "Inequality Defect"),
]
for ax, (metric_key, title) in zip(axes, panels):
    ax.plot(all_metrics["batch_ix"], np.asarray(all_metrics[metric_key]))
    ax.set_xlabel("Batch")
    ax.set_title(title)

# Show that defect converges close to zero
axes[3].axhline(0.0, c="gray", alpha=0.5)

plt.show()
0.3%
0.7%
1.0%
1.3%
1.7%
2.0%
2.3%
2.6%
3.0%
3.3%
3.6%
4.0%
4.3%
4.6%
5.0%
5.3%
5.6%
6.0%
6.3%
6.6%
6.9%
7.3%
7.6%
7.9%
8.3%
8.6%
8.9%
9.3%
9.6%
9.9%
10.2%
10.6%
10.9%
11.2%
11.6%
11.9%
12.2%
12.6%
12.9%
13.2%
13.6%
13.9%
14.2%
14.5%
14.9%
15.2%
15.5%
15.9%
16.2%
16.5%
16.9%
17.2%
17.5%
17.9%
18.2%
18.5%
18.8%
19.2%
19.5%
19.8%
20.2%
20.5%
20.8%
21.2%
21.5%
21.8%
22.1%
22.5%
22.8%
23.1%
23.5%
23.8%
24.1%
24.5%
24.8%
25.1%
25.5%
25.8%
26.1%
26.4%
26.8%
27.1%
27.4%
27.8%
28.1%
28.4%
28.8%
29.1%
29.4%
29.8%
30.1%
30.4%
30.7%
31.1%
31.4%
31.7%
32.1%
32.4%
32.7%
33.1%
33.4%
33.7%
34.0%
34.4%
34.7%
35.0%
35.4%
35.7%
36.0%
36.4%
36.7%
37.0%
37.4%
37.7%
38.0%
38.3%
38.7%
39.0%
39.3%
39.7%
40.0%
40.3%
40.7%
41.0%
41.3%
41.7%
42.0%
42.3%
42.6%
43.0%
43.3%
43.6%
44.0%
44.3%
44.6%
45.0%
45.3%
45.6%
45.9%
46.3%
46.6%
46.9%
47.3%
47.6%
47.9%
48.3%
48.6%
48.9%
49.3%
49.6%
49.9%
50.2%
50.6%
50.9%
51.2%
51.6%
51.9%
52.2%
52.6%
52.9%
53.2%
53.6%
53.9%
54.2%
54.5%
54.9%
55.2%
55.5%
55.9%
56.2%
56.5%
56.9%
57.2%
57.5%
57.9%
58.2%
58.5%
58.8%
59.2%
59.5%
59.8%
60.2%
60.5%
60.8%
61.2%
61.5%
61.8%
62.1%
62.5%
62.8%
63.1%
63.5%
63.8%
64.1%
64.5%
64.8%
65.1%
65.5%
65.8%
66.1%
66.4%
66.8%
67.1%
67.4%
67.8%
68.1%
68.4%
68.8%
69.1%
69.4%
69.8%
70.1%
70.4%
70.7%
71.1%
71.4%
71.7%
72.1%
72.4%
72.7%
73.1%
73.4%
73.7%
74.0%
74.4%
74.7%
75.0%
75.4%
75.7%
76.0%
76.4%
76.7%
77.0%
77.4%
77.7%
78.0%
78.3%
78.7%
79.0%
79.3%
79.7%
80.0%
80.3%
80.7%
81.0%
81.3%
81.7%
82.0%
82.3%
82.6%
83.0%
83.3%
83.6%
84.0%
84.3%
84.6%
85.0%
85.3%
85.6%
85.9%
86.3%
86.6%
86.9%
87.3%
87.6%
87.9%
88.3%
88.6%
88.9%
89.3%
89.6%
89.9%
90.2%
90.6%
90.9%
91.2%
91.6%
91.9%
92.2%
92.6%
92.9%
93.2%
93.6%
93.9%
94.2%
94.5%
94.9%
95.2%
95.5%
95.9%
96.2%
96.5%
96.9%
97.2%
97.5%
97.9%
98.2%
98.5%
98.8%
99.2%
99.5%
99.8%
100.0%
100.0%
2.0%
4.0%
6.0%
7.9%
9.9%
11.9%
13.9%
15.9%
17.9%
19.9%
21.9%
23.8%
25.8%
27.8%
29.8%
31.8%
33.8%
35.8%
37.8%
39.7%
41.7%
43.7%
45.7%
47.7%
49.7%
51.7%
53.7%
55.6%
57.6%
59.6%
61.6%
63.6%
65.6%
67.6%
69.6%
71.5%
73.5%
75.5%
77.5%
79.5%
81.5%
83.5%
85.5%
87.4%
89.4%
91.4%
93.4%
95.4%
97.4%
99.4%
100.0%
100.0%
---------------------------------------------------------------------------
UnpicklingError Traceback (most recent call last)
Cell In[2], line 128
125 del batch_ix, all_metrics, model, cmp, cooper_optimizer
127 # Post-training analysis and plotting
--> 128 all_metrics = torch.load(checkpoint_path + "/checkpoint.pth", weights_only=True)["all_metrics"]
130 fig, (ax0, ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=4, sharex=True, figsize=(18, 4))
132 ax0.plot(all_metrics["batch_ix"], all_metrics["train_loss"])
File ~/checkouts/readthedocs.org/user_builds/cooper/envs/latest/lib/python3.10/site-packages/torch/serialization.py:1470, in load(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)
1462 return _load(
1463 opened_zipfile,
1464 map_location,
(...)
1467 **pickle_load_args,
1468 )
1469 except pickle.UnpicklingError as e:
-> 1470 raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
1471 return _load(
1472 opened_zipfile,
1473 map_location,
(...)
1476 **pickle_load_args,
1477 )
1478 if mmap:
UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
WeightsUnpickler error: Unsupported global: GLOBAL collections.defaultdict was not an allowed global by default. Please use `torch.serialization.add_safe_globals([defaultdict])` or the `torch.serialization.safe_globals([defaultdict])` context manager to allowlist this global if you trust this class/function.
Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.